diff --git a/docs/source/classification/accuracy.rst b/docs/source/classification/accuracy.rst index ae47d733b9f..9c22eab6b71 100644 --- a/docs/source/classification/accuracy.rst +++ b/docs/source/classification/accuracy.rst @@ -16,41 +16,41 @@ ________________ BinaryAccuracy ^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryAccuracy +.. autoclass:: torchmetrics.classification.BinaryAccuracy :noindex: MulticlassAccuracy ^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassAccuracy +.. autoclass:: torchmetrics.classification.MulticlassAccuracy :noindex: MultilabelAccuracy ^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelAccuracy +.. autoclass:: torchmetrics.classification.MultilabelAccuracy :noindex: Functional Interface ____________________ -.. autofunction:: torchmetrics.functional.accuracy +.. autofunction:: torchmetrics.functional.classification.accuracy :noindex: binary_accuracy ^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_accuracy +.. autofunction:: torchmetrics.functional.classification.binary_accuracy :noindex: multiclass_accuracy ^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_accuracy +.. autofunction:: torchmetrics.functional.classification.multiclass_accuracy :noindex: multilabel_accuracy ^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_accuracy +.. autofunction:: torchmetrics.functional.classification.multilabel_accuracy :noindex: diff --git a/docs/source/classification/auroc.rst b/docs/source/classification/auroc.rst index 4828f605229..99f3ee83862 100644 --- a/docs/source/classification/auroc.rst +++ b/docs/source/classification/auroc.rst @@ -18,19 +18,19 @@ ________________ BinaryAUROC ^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryAUROC +.. autoclass:: torchmetrics.classification.BinaryAUROC :noindex: MulticlassAUROC ^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassAUROC +.. autoclass:: torchmetrics.classification.MulticlassAUROC :noindex: MultilabelAUROC ^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelAUROC +.. autoclass:: torchmetrics.classification.MultilabelAUROC :noindex: Functional Interface @@ -42,17 +42,17 @@ ____________________ binary_auroc ^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_auroc +.. autofunction:: torchmetrics.functional.classification.binary_auroc :noindex: multiclass_auroc ^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_auroc +.. autofunction:: torchmetrics.functional.classification.multiclass_auroc :noindex: multilabel_auroc ^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_auroc +.. autofunction:: torchmetrics.functional.classification.multilabel_auroc :noindex: diff --git a/docs/source/classification/average_precision.rst b/docs/source/classification/average_precision.rst index b241f682a86..f75428832e1 100644 --- a/docs/source/classification/average_precision.rst +++ b/docs/source/classification/average_precision.rst @@ -16,19 +16,19 @@ ________________ BinaryAveragePrecision ^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryAveragePrecision +.. autoclass:: torchmetrics.classification.BinaryAveragePrecision :noindex: MulticlassAveragePrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassAveragePrecision +.. autoclass:: torchmetrics.classification.MulticlassAveragePrecision :noindex: MultilabelAveragePrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelAveragePrecision +.. 
autoclass:: torchmetrics.classification.MultilabelAveragePrecision :noindex: Functional Interface @@ -40,17 +40,17 @@ ____________________ binary_average_precision ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_average_precision +.. autofunction:: torchmetrics.functional.classification.binary_average_precision :noindex: multiclass_average_precision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_average_precision +.. autofunction:: torchmetrics.functional.classification.multiclass_average_precision :noindex: multilabel_average_precision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_average_precision +.. autofunction:: torchmetrics.functional.classification.multilabel_average_precision :noindex: diff --git a/docs/source/classification/calibration_error.rst b/docs/source/classification/calibration_error.rst index 78da990346a..baa2e5e0b1b 100644 --- a/docs/source/classification/calibration_error.rst +++ b/docs/source/classification/calibration_error.rst @@ -18,13 +18,13 @@ ________________ BinaryCalibrationError ^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryCalibrationError +.. autoclass:: torchmetrics.classification.BinaryCalibrationError :noindex: MulticlassCalibrationError ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassCalibrationError +.. autoclass:: torchmetrics.classification.MulticlassCalibrationError :noindex: Functional Interface @@ -36,11 +36,11 @@ ____________________ binary_calibration_error ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_calibration_error +.. autofunction:: torchmetrics.functional.classification.binary_calibration_error :noindex: multiclass_calibration_error ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_calibration_error +.. autofunction:: torchmetrics.functional.classification.multiclass_calibration_error :noindex: diff --git a/docs/source/classification/cohen_kappa.rst b/docs/source/classification/cohen_kappa.rst index e5617796659..e5f924e71d2 100644 --- a/docs/source/classification/cohen_kappa.rst +++ b/docs/source/classification/cohen_kappa.rst @@ -21,14 +21,14 @@ CohenKappa BinaryCohenKappa ^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryCohenKappa +.. autoclass:: torchmetrics.classification.BinaryCohenKappa :noindex: :exclude-members: update, compute MulticlassCohenKappa ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassCohenKappa +.. autoclass:: torchmetrics.classification.MulticlassCohenKappa :noindex: :exclude-members: update, compute @@ -44,11 +44,11 @@ cohen_kappa binary_cohen_kappa ^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_cohen_kappa +.. autofunction:: torchmetrics.functional.classification.binary_cohen_kappa :noindex: multiclass_cohen_kappa ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_cohen_kappa +.. autofunction:: torchmetrics.functional.classification.multiclass_cohen_kappa :noindex: diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index bde2207e043..78d1edb6250 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -21,23 +21,20 @@ ConfusionMatrix BinaryConfusionMatrix ^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryConfusionMatrix +.. 
autoclass:: torchmetrics.classification.BinaryConfusionMatrix :noindex: - :exclude-members: update, compute MulticlassConfusionMatrix ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassConfusionMatrix +.. autoclass:: torchmetrics.classification.MulticlassConfusionMatrix :noindex: - :exclude-members: update, compute MultilabelConfusionMatrix ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelConfusionMatrix +.. autoclass:: torchmetrics.classification.MultilabelConfusionMatrix :noindex: - :exclude-members: update, compute Functional Interface ____________________ @@ -51,17 +48,17 @@ confusion_matrix binary_confusion_matrix ^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_confusion_matrix +.. autofunction:: torchmetrics.functional.classification.binary_confusion_matrix :noindex: multiclass_confusion_matrix ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_confusion_matrix +.. autofunction:: torchmetrics.functional.classification.multiclass_confusion_matrix :noindex: multilabel_confusion_matrix ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_confusion_matrix +.. autofunction:: torchmetrics.functional.classification.multilabel_confusion_matrix :noindex: diff --git a/docs/source/classification/coverage_error.rst b/docs/source/classification/coverage_error.rst index c29f0b06332..16db100c474 100644 --- a/docs/source/classification/coverage_error.rst +++ b/docs/source/classification/coverage_error.rst @@ -13,7 +13,7 @@ ________________ .. autoclass:: torchmetrics.CoverageError :noindex: -.. autoclass:: torchmetrics.MultilabelCoverageError +.. autoclass:: torchmetrics.classification.MultilabelCoverageError :noindex: Functional Interface @@ -22,5 +22,5 @@ ____________________ .. autofunction:: torchmetrics.functional.coverage_error :noindex: -.. autofunction:: torchmetrics.functional.multilabel_coverage_error +.. autofunction:: torchmetrics.functional.classification.multilabel_coverage_error :noindex: diff --git a/docs/source/classification/exact_match.rst b/docs/source/classification/exact_match.rst index 6e7c77d6003..c3a9000d4c5 100644 --- a/docs/source/classification/exact_match.rst +++ b/docs/source/classification/exact_match.rst @@ -13,7 +13,7 @@ ________________ MultilabelExactMatch ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelExactMatch +.. autoclass:: torchmetrics.classification.MultilabelExactMatch :noindex: Functional Interface @@ -22,5 +22,5 @@ ____________________ multilabel_exact_match ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_exact_match +.. autofunction:: torchmetrics.functional.classification.multilabel_exact_match :noindex: diff --git a/docs/source/classification/f1_score.rst b/docs/source/classification/f1_score.rst index 2ed2894e45c..3bf41a63df5 100644 --- a/docs/source/classification/f1_score.rst +++ b/docs/source/classification/f1_score.rst @@ -19,19 +19,19 @@ F1Score BinaryF1Score ^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryF1Score +.. autoclass:: torchmetrics.classification.BinaryF1Score :noindex: MulticlassF1Score ^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassF1Score +.. autoclass:: torchmetrics.classification.MulticlassF1Score :noindex: MultilabelF1Score ^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelF1Score +.. autoclass:: torchmetrics.classification.MultilabelF1Score :noindex: Functional Interface @@ -46,17 +46,17 @@ f1_score binary_f1_score ^^^^^^^^^^^^^^^ -.. 
autofunction:: torchmetrics.functional.binary_f1_score +.. autofunction:: torchmetrics.functional.classification.binary_f1_score :noindex: multiclass_f1_score ^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_f1_score +.. autofunction:: torchmetrics.functional.classification.multiclass_f1_score :noindex: multilabel_f1_score ^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_f1_score +.. autofunction:: torchmetrics.functional.classification.multilabel_f1_score :noindex: diff --git a/docs/source/classification/fbeta_score.rst b/docs/source/classification/fbeta_score.rst index 7966fb617a5..a00eac98905 100644 --- a/docs/source/classification/fbeta_score.rst +++ b/docs/source/classification/fbeta_score.rst @@ -21,19 +21,19 @@ FBetaScore BinaryFBetaScore ^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryFBetaScore +.. autoclass:: torchmetrics.classification.BinaryFBetaScore :noindex: MulticlassFBetaScore ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassFBetaScore +.. autoclass:: torchmetrics.classification.MulticlassFBetaScore :noindex: MultilabelFBetaScore ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelFBetaScore +.. autoclass:: torchmetrics.classification.MultilabelFBetaScore :noindex: Functional Interface @@ -48,17 +48,17 @@ fbeta_score binary_fbeta_score ^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_fbeta_score +.. autofunction:: torchmetrics.functional.classification.binary_fbeta_score :noindex: multiclass_fbeta_score ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_fbeta_score +.. autofunction:: torchmetrics.functional.classification.multiclass_fbeta_score :noindex: multilabel_fbeta_score ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_fbeta_score +.. autofunction:: torchmetrics.functional.classification.multilabel_fbeta_score :noindex: diff --git a/docs/source/classification/hamming_distance.rst b/docs/source/classification/hamming_distance.rst index b86c41b2278..29cc56b765d 100644 --- a/docs/source/classification/hamming_distance.rst +++ b/docs/source/classification/hamming_distance.rst @@ -19,19 +19,19 @@ HammingDistance BinaryHammingDistance ^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryHammingDistance +.. autoclass:: torchmetrics.classification.BinaryHammingDistance :noindex: MulticlassHammingDistance ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassHammingDistance +.. autoclass:: torchmetrics.classification.MulticlassHammingDistance :noindex: MultilabelHammingDistance ^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelHammingDistance +.. autoclass:: torchmetrics.classification.MultilabelHammingDistance :noindex: Functional Interface @@ -46,17 +46,17 @@ hamming_distance binary_hamming_distance ^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_hamming_distance +.. autofunction:: torchmetrics.functional.classification.binary_hamming_distance :noindex: multiclass_hamming_distance ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_hamming_distance +.. autofunction:: torchmetrics.functional.classification.multiclass_hamming_distance :noindex: multilabel_hamming_distance ^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_hamming_distance +.. 
autofunction:: torchmetrics.functional.classification.multilabel_hamming_distance :noindex: diff --git a/docs/source/classification/hinge_loss.rst b/docs/source/classification/hinge_loss.rst index 4de979946b8..5a0b992f2a3 100644 --- a/docs/source/classification/hinge_loss.rst +++ b/docs/source/classification/hinge_loss.rst @@ -16,13 +16,13 @@ ________________ BinaryHingeLoss ^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryHingeLoss +.. autoclass:: torchmetrics.classification.BinaryHingeLoss :noindex: MulticlassHingeLoss ^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassHingeLoss +.. autoclass:: torchmetrics.classification.MulticlassHingeLoss :noindex: Functional Interface @@ -34,11 +34,11 @@ ____________________ binary_hinge_loss ^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_hinge_loss +.. autofunction:: torchmetrics.functional.classification.binary_hinge_loss :noindex: multiclass_hinge_loss ^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_hinge_loss +.. autofunction:: torchmetrics.functional.classification.multiclass_hinge_loss :noindex: diff --git a/docs/source/classification/jaccard_index.rst b/docs/source/classification/jaccard_index.rst index e08cb5f3ba0..2fb3f3332e0 100644 --- a/docs/source/classification/jaccard_index.rst +++ b/docs/source/classification/jaccard_index.rst @@ -19,21 +19,21 @@ CohenKappa BinaryJaccardIndex ^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryJaccardIndex +.. autoclass:: torchmetrics.classification.BinaryJaccardIndex :noindex: :exclude-members: update, compute MulticlassJaccardIndex ^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassJaccardIndex +.. autoclass:: torchmetrics.classification.MulticlassJaccardIndex :noindex: :exclude-members: update, compute MultilabelJaccardIndex ^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelJaccardIndex +.. autoclass:: torchmetrics.classification.MultilabelJaccardIndex :noindex: :exclude-members: update, compute @@ -50,17 +50,17 @@ jaccard_index binary_jaccard_index ^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_jaccard_index +.. autofunction:: torchmetrics.functional.classification.binary_jaccard_index :noindex: multiclass_jaccard_index ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_jaccard_index +.. autofunction:: torchmetrics.functional.classification.multiclass_jaccard_index :noindex: multilabel_jaccard_index ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_jaccard_index +.. autofunction:: torchmetrics.functional.classification.multilabel_jaccard_index :noindex: diff --git a/docs/source/classification/label_ranking_average_precision.rst b/docs/source/classification/label_ranking_average_precision.rst index 5804816600e..32f1b0867b5 100644 --- a/docs/source/classification/label_ranking_average_precision.rst +++ b/docs/source/classification/label_ranking_average_precision.rst @@ -13,7 +13,7 @@ ________________ .. autoclass:: torchmetrics.LabelRankingAveragePrecision :noindex: -.. autoclass:: torchmetrics.MultilabelRankingAveragePrecision +.. autoclass:: torchmetrics.classification.MultilabelRankingAveragePrecision :noindex: @@ -23,5 +23,5 @@ ____________________ .. autofunction:: torchmetrics.functional.label_ranking_average_precision :noindex: -.. autofunction:: torchmetrics.functional.multilabel_ranking_average_precision +.. 
autofunction:: torchmetrics.functional.classification.multilabel_ranking_average_precision :noindex: diff --git a/docs/source/classification/label_ranking_loss.rst b/docs/source/classification/label_ranking_loss.rst index 4ba1119507e..168b2c80ceb 100644 --- a/docs/source/classification/label_ranking_loss.rst +++ b/docs/source/classification/label_ranking_loss.rst @@ -14,7 +14,7 @@ ________________ :noindex: -.. autoclass:: torchmetrics.MultilabelRankingLoss +.. autoclass:: torchmetrics.classification.MultilabelRankingLoss :noindex: Functional Interface @@ -23,5 +23,5 @@ ____________________ .. autofunction:: torchmetrics.functional.label_ranking_loss :noindex: -.. autofunction:: torchmetrics.functional.multilabel_ranking_loss +.. autofunction:: torchmetrics.functional.classification.multilabel_ranking_loss :noindex: diff --git a/docs/source/classification/matthews_corr_coef.rst b/docs/source/classification/matthews_corr_coef.rst index 569a209c722..7b29686766b 100644 --- a/docs/source/classification/matthews_corr_coef.rst +++ b/docs/source/classification/matthews_corr_coef.rst @@ -21,21 +21,21 @@ MatthewsCorrCoef BinaryMatthewsCorrCoef ^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryMatthewsCorrCoef +.. autoclass:: torchmetrics.classification.BinaryMatthewsCorrCoef :noindex: :exclude-members: update, compute MulticlassMatthewsCorrCoef ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassMatthewsCorrCoef +.. autoclass:: torchmetrics.classification.MulticlassMatthewsCorrCoef :noindex: :exclude-members: update, compute MultilabelMatthewsCorrCoef ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelMatthewsCorrCoef +.. autoclass:: torchmetrics.classification.MultilabelMatthewsCorrCoef :noindex: :exclude-members: update, compute @@ -52,17 +52,17 @@ matthews_corrcoef binary_matthews_corrcoef ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_matthews_corrcoef +.. autofunction:: torchmetrics.functional.classification.binary_matthews_corrcoef :noindex: multiclass_matthews_corrcoef ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_matthews_corrcoef +.. autofunction:: torchmetrics.functional.classification.multiclass_matthews_corrcoef :noindex: multilabel_matthews_corrcoef ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_matthews_corrcoef +.. autofunction:: torchmetrics.functional.classification.multilabel_matthews_corrcoef :noindex: diff --git a/docs/source/classification/precision.rst b/docs/source/classification/precision.rst index 1acefff94d4..5115c78bcc2 100644 --- a/docs/source/classification/precision.rst +++ b/docs/source/classification/precision.rst @@ -18,19 +18,19 @@ ________________ BinaryPrecision ^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryPrecision +.. autoclass:: torchmetrics.classification.BinaryPrecision :noindex: MulticlassPrecision ^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassPrecision +.. autoclass:: torchmetrics.classification.MulticlassPrecision :noindex: MultilabelPrecision ^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelPrecision +.. autoclass:: torchmetrics.classification.MultilabelPrecision :noindex: Functional Interface @@ -42,17 +42,17 @@ ____________________ binary_precision ^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_precision +.. autofunction:: torchmetrics.functional.classification.binary_precision :noindex: multiclass_precision ^^^^^^^^^^^^^^^^^^^^ -.. 
autofunction:: torchmetrics.functional.multiclass_precision +.. autofunction:: torchmetrics.functional.classification.multiclass_precision :noindex: multilabel_precision ^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_precision +.. autofunction:: torchmetrics.functional.classification.multilabel_precision :noindex: diff --git a/docs/source/classification/precision_recall_curve.rst b/docs/source/classification/precision_recall_curve.rst index 412470ccd3d..dbe0421ed12 100644 --- a/docs/source/classification/precision_recall_curve.rst +++ b/docs/source/classification/precision_recall_curve.rst @@ -16,19 +16,19 @@ ________________ BinaryPrecisionRecallCurve ^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryPrecisionRecallCurve +.. autoclass:: torchmetrics.classification.BinaryPrecisionRecallCurve :noindex: MulticlassPrecisionRecallCurve ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassPrecisionRecallCurve +.. autoclass:: torchmetrics.classification.MulticlassPrecisionRecallCurve :noindex: MultilabelPrecisionRecallCurve ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelPrecisionRecallCurve +.. autoclass:: torchmetrics.classification.MultilabelPrecisionRecallCurve :noindex: Functional Interface @@ -40,17 +40,17 @@ ____________________ binary_precision_recall_curve ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_precision_recall_curve +.. autofunction:: torchmetrics.functional.classification.binary_precision_recall_curve :noindex: multiclass_precision_recall_curve ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_precision_recall_curve +.. autofunction:: torchmetrics.functional.classification.multiclass_precision_recall_curve :noindex: multilabel_precision_recall_curve ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_precision_recall_curve +.. autofunction:: torchmetrics.functional.classification.multilabel_precision_recall_curve :noindex: diff --git a/docs/source/classification/recall.rst b/docs/source/classification/recall.rst index 0210cddaed0..771fe4ccfb6 100644 --- a/docs/source/classification/recall.rst +++ b/docs/source/classification/recall.rst @@ -16,19 +16,19 @@ ________________ BinaryRecall ^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryRecall +.. autoclass:: torchmetrics.classification.BinaryRecall :noindex: MulticlassRecall ^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassRecall +.. autoclass:: torchmetrics.classification.MulticlassRecall :noindex: MultilabelRecall ^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelRecall +.. autoclass:: torchmetrics.classification.MultilabelRecall :noindex: Functional Interface @@ -40,17 +40,17 @@ ____________________ binary_recall ^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_recall +.. autofunction:: torchmetrics.functional.classification.binary_recall :noindex: multiclass_recall ^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_recall +.. autofunction:: torchmetrics.functional.classification.multiclass_recall :noindex: multilabel_recall ^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_recall +.. 
autofunction:: torchmetrics.functional.classification.multilabel_recall :noindex: diff --git a/docs/source/classification/recall_at_fixed_precision.rst b/docs/source/classification/recall_at_fixed_precision.rst index f585f2abc7e..7a43aa23064 100644 --- a/docs/source/classification/recall_at_fixed_precision.rst +++ b/docs/source/classification/recall_at_fixed_precision.rst @@ -13,19 +13,19 @@ ________________ BinaryRecallAtFixedPrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryRecallAtFixedPrecision +.. autoclass:: torchmetrics.classification.BinaryRecallAtFixedPrecision :noindex: MulticlassRecallAtFixedPrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassRecallAtFixedPrecision +.. autoclass:: torchmetrics.classification.MulticlassRecallAtFixedPrecision :noindex: MultilabelRecallAtFixedPrecision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelRecallAtFixedPrecision +.. autoclass:: torchmetrics.classification.MultilabelRecallAtFixedPrecision :noindex: Functional Interface @@ -34,17 +34,17 @@ ____________________ binary_recall_at_fixed_precision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_recall_at_fixed_precision +.. autofunction:: torchmetrics.functional.classification.binary_recall_at_fixed_precision :noindex: multiclass_recall_at_fixed_precision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_recall_at_fixed_precision +.. autofunction:: torchmetrics.functional.classification.multiclass_recall_at_fixed_precision :noindex: multilabel_recall_at_fixed_precision ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_recall_at_fixed_precision +.. autofunction:: torchmetrics.functional.classification.multilabel_recall_at_fixed_precision :noindex: diff --git a/docs/source/classification/roc.rst b/docs/source/classification/roc.rst index 04dae356790..eef65dd845e 100644 --- a/docs/source/classification/roc.rst +++ b/docs/source/classification/roc.rst @@ -16,19 +16,19 @@ ________________ BinaryROC ^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryROC +.. autoclass:: torchmetrics.classification.BinaryROC :noindex: MulticlassROC ^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassROC +.. autoclass:: torchmetrics.classification.MulticlassROC :noindex: MultilabelROC ^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelROC +.. autoclass:: torchmetrics.classification.MultilabelROC :noindex: Functional Interface @@ -40,17 +40,17 @@ ____________________ binary_roc ^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_roc +.. autofunction:: torchmetrics.functional.classification.binary_roc :noindex: multiclass_roc ^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_roc +.. autofunction:: torchmetrics.functional.classification.multiclass_roc :noindex: multilabel_roc ^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_roc +.. autofunction:: torchmetrics.functional.classification.multilabel_roc :noindex: diff --git a/docs/source/classification/specificity.rst b/docs/source/classification/specificity.rst index 00e0bbb8932..19299e0c863 100644 --- a/docs/source/classification/specificity.rst +++ b/docs/source/classification/specificity.rst @@ -16,19 +16,19 @@ ________________ BinarySpecificity ^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinarySpecificity +.. autoclass:: torchmetrics.classification.BinarySpecificity :noindex: MulticlassSpecificity ^^^^^^^^^^^^^^^^^^^^^ -.. 
autoclass:: torchmetrics.MulticlassSpecificity +.. autoclass:: torchmetrics.classification.MulticlassSpecificity :noindex: MultilabelSpecificity ^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelSpecificity +.. autoclass:: torchmetrics.classification.MultilabelSpecificity :noindex: @@ -41,17 +41,17 @@ ____________________ binary_specificity ^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_specificity +.. autofunction:: torchmetrics.functional.classification.binary_specificity :noindex: multiclass_specificity ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_specificity +.. autofunction:: torchmetrics.functional.classification.multiclass_specificity :noindex: multilabel_specificity ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_specificity +.. autofunction:: torchmetrics.functional.classification.multilabel_specificity :noindex: diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 382e6048534..000027dbbea 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -21,21 +21,21 @@ StatScores BinaryStatScores ^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.BinaryStatScores +.. autoclass:: torchmetrics.classification.BinaryStatScores :noindex: :exclude-members: update, compute MulticlassStatScores ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MulticlassStatScores +.. autoclass:: torchmetrics.classification.MulticlassStatScores :noindex: :exclude-members: update, compute MultilabelStatScores ^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: torchmetrics.MultilabelStatScores +.. autoclass:: torchmetrics.classification.MultilabelStatScores :noindex: :exclude-members: update, compute @@ -51,17 +51,17 @@ stat_scores binary_stat_scores ^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.binary_stat_scores +.. autofunction:: torchmetrics.functional.classification.binary_stat_scores :noindex: multiclass_stat_scores ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multiclass_stat_scores +.. autofunction:: torchmetrics.functional.classification.multiclass_stat_scores :noindex: multilabel_stat_scores ^^^^^^^^^^^^^^^^^^^^^^ -.. autofunction:: torchmetrics.functional.multilabel_stat_scores +.. 
autofunction:: torchmetrics.functional.classification.multilabel_stat_scores :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 8e6910e318f..bb77854d816 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -26,25 +26,6 @@ ROC, Accuracy, AveragePrecision, - BinaryAccuracy, - BinaryAUROC, - BinaryAveragePrecision, - BinaryCalibrationError, - BinaryCohenKappa, - BinaryConfusionMatrix, - BinaryF1Score, - BinaryFBetaScore, - BinaryHammingDistance, - BinaryHingeLoss, - BinaryJaccardIndex, - BinaryMatthewsCorrCoef, - BinaryPrecision, - BinaryPrecisionRecallCurve, - BinaryRecall, - BinaryRecallAtFixedPrecision, - BinaryROC, - BinarySpecificity, - BinaryStatScores, BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision, @@ -61,45 +42,6 @@ LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, - MulticlassAccuracy, - MulticlassAUROC, - MulticlassAveragePrecision, - MulticlassCalibrationError, - MulticlassCohenKappa, - MulticlassConfusionMatrix, - MulticlassF1Score, - MulticlassFBetaScore, - MulticlassHammingDistance, - MulticlassHingeLoss, - MulticlassJaccardIndex, - MulticlassMatthewsCorrCoef, - MulticlassPrecision, - MulticlassPrecisionRecallCurve, - MulticlassRecall, - MulticlassRecallAtFixedPrecision, - MulticlassROC, - MulticlassSpecificity, - MulticlassStatScores, - MultilabelAccuracy, - MultilabelAUROC, - MultilabelAveragePrecision, - MultilabelConfusionMatrix, - MultilabelCoverageError, - MultilabelExactMatch, - MultilabelF1Score, - MultilabelFBetaScore, - MultilabelHammingDistance, - MultilabelJaccardIndex, - MultilabelMatthewsCorrCoef, - MultilabelPrecision, - MultilabelPrecisionRecallCurve, - MultilabelRankingAveragePrecision, - MultilabelRankingLoss, - MultilabelRecall, - MultilabelRecallAtFixedPrecision, - MultilabelROC, - MultilabelSpecificity, - MultilabelStatScores, Precision, PrecisionRecallCurve, Recall, @@ -169,26 +111,8 @@ __all__ = [ "functional", "Accuracy", - "BinaryAccuracy", - "MulticlassAccuracy", - "MultilabelAccuracy", "AUC", "AUROC", - "BinaryAUROC", - "BinaryAveragePrecision", - "BinaryPrecisionRecallCurve", - "BinaryRecallAtFixedPrecision", - "BinaryROC", - "MultilabelROC", - "MulticlassAUROC", - "MulticlassAveragePrecision", - "MulticlassPrecisionRecallCurve", - "MulticlassRecallAtFixedPrecision", - "MulticlassROC", - "MultilabelAUROC", - "MultilabelAveragePrecision", - "MultilabelPrecisionRecallCurve", - "MultilabelRecallAtFixedPrecision", "AveragePrecision", "BinnedAveragePrecision", "BinnedPrecisionRecallCurve", @@ -201,12 +125,7 @@ "CharErrorRate", "CHRFScore", "CohenKappa", - "BinaryCohenKappa", - "MulticlassCohenKappa", "ConfusionMatrix", - "BinaryConfusionMatrix", - "MulticlassConfusionMatrix", - "MultilabelConfusionMatrix", "CosineSimilarity", "CoverageError", "Dice", @@ -215,31 +134,15 @@ "ExplainedVariance", "ExtendedEditDistance", "F1Score", - "BinaryF1Score", - "MulticlassF1Score", - "MultilabelF1Score", "FBetaScore", - "BinaryFBetaScore", - "MulticlassFBetaScore", - "MultilabelFBetaScore", "HammingDistance", - "BinaryHammingDistance", - "MultilabelHammingDistance", - "MulticlassHammingDistance", "HingeLoss", "JaccardIndex", - "BinaryJaccardIndex", - "MulticlassJaccardIndex", - "MultilabelJaccardIndex", - "MultilabelExactMatch", "KLDivergence", "LabelRankingAveragePrecision", "LabelRankingLoss", "MatchErrorRate", "MatthewsCorrCoef", - "BinaryMatthewsCorrCoef", - "MulticlassMatthewsCorrCoef", - "MultilabelMatthewsCorrCoef", "MaxMetric", "MeanAbsoluteError", 
"MeanAbsolutePercentageError", @@ -257,16 +160,10 @@ "PermutationInvariantTraining", "Perplexity", "Precision", - "BinaryPrecision", - "MulticlassPrecision", - "MultilabelPrecision", "PrecisionRecallCurve", "PeakSignalNoiseRatio", "R2Score", "Recall", - "BinaryRecall", - "MulticlassRecall", - "MultilabelRecall", "RetrievalFallOut", "RetrievalHitRate", "RetrievalMAP", @@ -285,17 +182,11 @@ "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", - "BinarySpecificity", - "MulticlassSpecificity", - "MultilabelSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", "SQuAD", "StructuralSimilarityIndexMeasure", "StatScores", - "BinaryStatScores", - "MulticlassStatScores", - "MultilabelStatScores", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", @@ -304,11 +195,4 @@ "WordErrorRate", "WordInfoLost", "WordInfoPreserved", - "BinaryCalibrationError", - "MulticlassHingeLoss", - "BinaryHingeLoss", - "MulticlassCalibrationError", - "MultilabelCoverageError", - "MultilabelRankingAveragePrecision", - "MultilabelRankingLoss", ] diff --git a/src/torchmetrics/classification/accuracy.py b/src/torchmetrics/classification/accuracy.py index 7a92b06dea5..a13f51a6c72 100644 --- a/src/torchmetrics/classification/accuracy.py +++ b/src/torchmetrics/classification/accuracy.py @@ -15,6 +15,7 @@ import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.functional.classification.accuracy import ( _accuracy_compute, @@ -73,7 +74,7 @@ class BinaryAccuracy(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics import BinaryAccuracy + >>> from torchmetrics.classification import BinaryAccuracy >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryAccuracy() @@ -81,7 +82,7 @@ class BinaryAccuracy(BinaryStatScores): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinaryAccuracy + >>> from torchmetrics.classification import BinaryAccuracy >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryAccuracy() @@ -89,7 +90,7 @@ class BinaryAccuracy(BinaryStatScores): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinaryAccuracy + >>> from torchmetrics.classification import BinaryAccuracy >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -168,7 +169,7 @@ class MulticlassAccuracy(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassAccuracy + >>> from torchmetrics.classification import MulticlassAccuracy >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassAccuracy(num_classes=3) @@ -179,7 +180,7 @@ class MulticlassAccuracy(MulticlassStatScores): tensor([0.5000, 1.0000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassAccuracy + >>> from torchmetrics.classification import MulticlassAccuracy >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... 
[0.16, 0.26, 0.58], @@ -195,7 +196,7 @@ class MulticlassAccuracy(MulticlassStatScores): tensor([0.5000, 1.0000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassAccuracy + >>> from torchmetrics.classification import MulticlassAccuracy >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise') @@ -271,7 +272,7 @@ class MultilabelAccuracy(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelAccuracy + >>> from torchmetrics.classification import MultilabelAccuracy >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelAccuracy(num_labels=3) @@ -282,7 +283,7 @@ class MultilabelAccuracy(MultilabelStatScores): tensor([1.0000, 0.5000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelAccuracy + >>> from torchmetrics.classification import MultilabelAccuracy >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelAccuracy(num_labels=3) @@ -293,7 +294,7 @@ class MultilabelAccuracy(MultilabelStatScores): tensor([1.0000, 0.5000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics import MultilabelAccuracy + >>> from torchmetrics.classification import MultilabelAccuracy >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -456,6 +457,37 @@ class Accuracy(StatScores): correct: Tensor total: Tensor + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + subset_accuracy: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryAccuracy(threshold, **kwargs) + if task == "multiclass": + return MulticlassAccuracy(num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelAccuracy(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, @@ -466,6 +498,10 @@ def __init__( top_k: Optional[int] = None, multiclass: Optional[bool] = None, subset_accuracy: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index 7a0807de349..5365e46dbec 100644 --- a/src/torchmetrics/classification/auroc.py +++ 
b/src/torchmetrics/classification/auroc.py @@ -83,7 +83,7 @@ class BinaryAUROC(BinaryPrecisionRecallCurve): A single scalar with the auroc score Example: - >>> from torchmetrics import BinaryAUROC + >>> from torchmetrics.classification import BinaryAUROC >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> metric = BinaryAUROC(thresholds=None) @@ -169,7 +169,7 @@ class MulticlassAUROC(MulticlassPrecisionRecallCurve): If `average="macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics import MulticlassAUROC + >>> from torchmetrics.classification import MulticlassAUROC >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -271,7 +271,7 @@ class MultilabelAUROC(MultilabelPrecisionRecallCurve): If `average="micro|macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics import MultilabelAUROC + >>> from torchmetrics.classification import MultilabelAUROC >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -404,12 +404,43 @@ class AUROC(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", + max_fpr: Optional[float] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAUROC(max_fpr, **kwargs) + if task == "multiclass": + return MulticlassAUROC(num_classes, average, **kwargs) + if task == "multilabel": + return MultilabelAUROC(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = "macro", max_fpr: Optional[float] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index 86a59e026fa..53d5bffb0e9 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -84,7 +84,7 @@ class BinaryAveragePrecision(BinaryPrecisionRecallCurve): A single scalar with the average precision score Example: - >>> from torchmetrics import BinaryAveragePrecision + >>> from torchmetrics.classification import BinaryAveragePrecision >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> metric = BinaryAveragePrecision(thresholds=None) @@ -162,7 +162,7 @@ class MulticlassAveragePrecision(MulticlassPrecisionRecallCurve): If `average="macro"|"weighted"` then a single scalar is returned. 
Example: - >>> from torchmetrics import MulticlassAveragePrecision + >>> from torchmetrics.classification import MulticlassAveragePrecision >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -269,7 +269,7 @@ class MultilabelAveragePrecision(MultilabelPrecisionRecallCurve): If `average="micro|macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics import MultilabelAveragePrecision + >>> from torchmetrics.classification import MultilabelAveragePrecision >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -384,11 +384,41 @@ class AveragePrecision(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + average: Optional[str] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryAveragePrecision(**kwargs) + if task == "multiclass": + return MulticlassAveragePrecision(num_classes, average, **kwargs) + if task == "multilabel": + return MultilabelAveragePrecision(num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) diff --git a/src/torchmetrics/classification/binned_precision_recall.py b/src/torchmetrics/classification/binned_precision_recall.py index 21be67c0475..d7253527ae3 100644 --- a/src/torchmetrics/classification/binned_precision_recall.py +++ b/src/torchmetrics/classification/binned_precision_recall.py @@ -19,6 +19,7 @@ from torchmetrics.functional.classification.average_precision import _average_precision_compute_with_precision_recall from torchmetrics.metric import Metric from torchmetrics.utilities.data import METRIC_EPS, to_onehot +from torchmetrics.utilities.prints import rank_zero_warn def _recall_at_precision( @@ -49,6 +50,10 @@ class BinnedPrecisionRecallCurve(Metric): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warning:: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `PrecisionRecallCurve` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -122,6 +127,11 @@ def __init__( thresholds: Union[int, Tensor, List[float]] = 100, **kwargs: Any, ) -> None: + rank_zero_warn( + "Metric `BinnedPrecisionRecallCurve` has been deprecated in v0.10 and will be completely removed in v0.11."
+ " Instead, use the refactored version of `PrecisionRecallCurve` by specifying the `thresholds` argument.", + DeprecationWarning, + ) super().__init__(**kwargs) self.num_classes = num_classes @@ -187,6 +197,10 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warning:: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `AveragePrecision` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -225,6 +239,19 @@ class BinnedAveragePrecision(BinnedPrecisionRecallCurve): [tensor(1.0000), tensor(1.0000), tensor(0.2500), tensor(0.2500), tensor(-0.)] """ + def __init__( + self, + num_classes: int, + thresholds: Union[int, Tensor, List[float]] = 100, + **kwargs: Any, + ) -> None: + rank_zero_warn( + "Metric `BinnedAveragePrecision` has been deprecated in v0.10 and will be completely removed in v0.11." + " Instead, use the refactored version of `AveragePrecision` by specifying the `thresholds` argument.", + DeprecationWarning, + ) + super().__init__(num_classes=num_classes, thresholds=thresholds, **kwargs) + def compute(self) -> Union[List[Tensor], Tensor]: # type: ignore precisions, recalls, _ = super().compute() return _average_precision_compute_with_precision_recall(precisions, recalls, self.num_classes, average=None) @@ -236,6 +263,10 @@ class BinnedRecallAtFixedPrecision(BinnedPrecisionRecallCurve): Computation is performed in constant-memory by computing precision and recall for ``thresholds`` buckets/thresholds (evenly distributed between 0 and 1). + .. warning:: + This metric has been deprecated in v0.10 and will be removed in v0.11. + Instead use `RecallAtFixedPrecision` metric with the `thresholds` argument set accordingly. + Forward accepts - ``preds`` (float tensor): ``(N, ...)`` (binary) or ``(N, C, ...)`` (multiclass) tensor @@ -283,6 +314,11 @@ def __init__( thresholds: Union[int, Tensor, List[float]] = 100, **kwargs: Any, ) -> None: + rank_zero_warn( + "Metric `BinnedRecallAtFixedPrecision` has been deprecated in v0.10 and will be completely removed in v0.11." + " Instead, use the refactored version of `RecallAtFixedPrecision` by specifying the `thresholds` argument.", + DeprecationWarning, + ) super().__init__(num_classes=num_classes, thresholds=thresholds, **kwargs) self.min_precision = min_precision diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 1bbcfcf59b0..3fcf8ace9a2 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -73,7 +73,7 @@ class BinaryCalibrationError(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> from torchmetrics import BinaryCalibrationError + >>> from torchmetrics.classification import BinaryCalibrationError >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> metric = BinaryCalibrationError(n_bins=2, norm='l1') @@ -165,7 +165,7 @@ class MulticlassCalibrationError(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example: - >>> from torchmetrics import MulticlassCalibrationError + >>> from torchmetrics.classification import MulticlassCalibrationError >>> preds = torch.tensor([[0.25, 0.20, 0.55], ... [0.55, 0.05, 0.40], ... [0.10, 0.30, 0.60], @@ -265,6 +265,27 @@ class CalibrationError(Metric): confidences: List[Tensor] accuracies: List[Tensor] + def __new__( + cls, + n_bins: int = 15, + norm: str = "l1", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(n_bins=n_bins, norm=norm, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCalibrationError(**kwargs) + if task == "multiclass": + return MulticlassCalibrationError(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, n_bins: int = 15, diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 734b7840a6c..3306c419117 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -64,7 +64,7 @@ class labels. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import BinaryCohenKappa + >>> from torchmetrics.classification import BinaryCohenKappa >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> metric = BinaryCohenKappa() @@ -72,7 +72,7 @@ class labels. tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics import BinaryCohenKappa + >>> from torchmetrics.classification import BinaryCohenKappa >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> metric = BinaryCohenKappa() @@ -138,7 +138,7 @@ class labels. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics import MulticlassCohenKappa + >>> from torchmetrics.classification import MulticlassCohenKappa >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassCohenKappa(num_classes=3) @@ -146,7 +146,7 @@ class labels. tensor(0.6364) Example (pred is float tensor): - >>> from torchmetrics import MulticlassCohenKappa + >>> from torchmetrics.classification import MulticlassCohenKappa >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -236,6 +236,27 @@ class labels. 
full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: Optional[int] = None, + weights: Optional[str] = None, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryCohenKappa(threshold, **kwargs) + if task == "multiclass": + return MulticlassCohenKappa(num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index aadb174305e..c012dfddc1b 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -67,7 +67,7 @@ class BinaryConfusionMatrix(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import BinaryConfusionMatrix + >>> from torchmetrics.classification import BinaryConfusionMatrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> metric = BinaryConfusionMatrix() @@ -76,7 +76,7 @@ class BinaryConfusionMatrix(Metric): [1, 1]]) Example (preds is float tensor): - >>> from torchmetrics import BinaryConfusionMatrix + >>> from torchmetrics.classification import BinaryConfusionMatrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> metric = BinaryConfusionMatrix() @@ -155,7 +155,7 @@ class MulticlassConfusionMatrix(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics import MulticlassConfusionMatrix + >>> from torchmetrics.classification import MulticlassConfusionMatrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassConfusionMatrix(num_classes=3) @@ -165,7 +165,7 @@ class MulticlassConfusionMatrix(Metric): [0, 0, 1]]) Example (pred is float tensor): - >>> from torchmetrics import MulticlassConfusionMatrix + >>> from torchmetrics.classification import MulticlassConfusionMatrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -251,7 +251,7 @@ class MultilabelConfusionMatrix(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
Example (preds is int tensor): - >>> from torchmetrics import MultilabelConfusionMatrix + >>> from torchmetrics.classification import MultilabelConfusionMatrix >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelConfusionMatrix(num_labels=3) @@ -261,7 +261,7 @@ class MultilabelConfusionMatrix(Metric): [[0, 1], [0, 1]]]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelConfusionMatrix + >>> from torchmetrics.classification import MultilabelConfusionMatrix >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelConfusionMatrix(num_labels=3) @@ -390,6 +390,31 @@ class ConfusionMatrix(Metric): full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: Optional[int] = None, + normalize: Optional[str] = None, + threshold: float = 0.5, + multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryConfusionMatrix(threshold, **kwargs) + if task == "multiclass": + return MulticlassConfusionMatrix(num_classes, **kwargs) + if task == "multilabel": + return MultilabelConfusionMatrix(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/exact_match.py b/src/torchmetrics/classification/exact_match.py index 0f593db1385..5577b6d9a83 100644 --- a/src/torchmetrics/classification/exact_match.py +++ b/src/torchmetrics/classification/exact_match.py @@ -73,7 +73,7 @@ class MultilabelExactMatch(Metric): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelExactMatch + >>> from torchmetrics.classification import MultilabelExactMatch >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelExactMatch(num_labels=3) @@ -81,7 +81,7 @@ class MultilabelExactMatch(Metric): tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics import MultilabelExactMatch + >>> from torchmetrics.classification import MultilabelExactMatch >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelExactMatch(num_labels=3) @@ -89,7 +89,7 @@ class MultilabelExactMatch(Metric): tensor(0.5000) Example (multidim tensors): - >>> from torchmetrics import MultilabelExactMatch + >>> from torchmetrics.classification import MultilabelExactMatch >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ diff --git a/src/torchmetrics/classification/f_beta.py b/src/torchmetrics/classification/f_beta.py index 468c8e67349..73c3bdc446c 100644 --- a/src/torchmetrics/classification/f_beta.py +++ b/src/torchmetrics/classification/f_beta.py @@ -71,7 +71,7 @@ class BinaryFBetaScore(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. 
Example (preds is int tensor): - >>> from torchmetrics import BinaryFBetaScore + >>> from torchmetrics.classification import BinaryFBetaScore >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryFBetaScore(beta=2.0) @@ -79,7 +79,7 @@ class BinaryFBetaScore(BinaryStatScores): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinaryFBetaScore + >>> from torchmetrics.classification import BinaryFBetaScore >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryFBetaScore(beta=2.0) @@ -87,7 +87,7 @@ class BinaryFBetaScore(BinaryStatScores): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinaryFBetaScore + >>> from torchmetrics.classification import BinaryFBetaScore >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -187,7 +187,7 @@ class MulticlassFBetaScore(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassFBetaScore + >>> from torchmetrics.classification import MulticlassFBetaScore >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3) @@ -198,7 +198,7 @@ class MulticlassFBetaScore(MulticlassStatScores): tensor([0.5556, 0.8333, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassFBetaScore + >>> from torchmetrics.classification import MulticlassFBetaScore >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -214,7 +214,7 @@ class MulticlassFBetaScore(MulticlassStatScores): tensor([0.5556, 0.8333, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassFBetaScore + >>> from torchmetrics.classification import MulticlassFBetaScore >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise') @@ -315,7 +315,7 @@ class MultilabelFBetaScore(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelFBetaScore + >>> from torchmetrics.classification import MultilabelFBetaScore >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3) @@ -326,7 +326,7 @@ class MultilabelFBetaScore(MultilabelStatScores): tensor([1.0000, 0.0000, 0.8333]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelFBetaScore + >>> from torchmetrics.classification import MultilabelFBetaScore >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3) @@ -337,7 +337,7 @@ class MultilabelFBetaScore(MultilabelStatScores): tensor([1.0000, 0.0000, 0.8333]) Example (multidim tensors): - >>> from torchmetrics import MultilabelFBetaScore + >>> from torchmetrics.classification import MultilabelFBetaScore >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -424,7 +424,7 @@ class BinaryF1Score(BinaryFBetaScore): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics import BinaryF1Score + >>> from torchmetrics.classification import BinaryF1Score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryF1Score() @@ -432,7 +432,7 @@ class BinaryF1Score(BinaryFBetaScore): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinaryF1Score + >>> from torchmetrics.classification import BinaryF1Score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryF1Score() @@ -440,7 +440,7 @@ class BinaryF1Score(BinaryFBetaScore): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinaryF1Score + >>> from torchmetrics.classification import BinaryF1Score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -531,7 +531,7 @@ class MulticlassF1Score(MulticlassFBetaScore): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassF1Score + >>> from torchmetrics.classification import MulticlassF1Score >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassF1Score(num_classes=3) @@ -542,7 +542,7 @@ class MulticlassF1Score(MulticlassFBetaScore): tensor([0.6667, 0.6667, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassF1Score + >>> from torchmetrics.classification import MulticlassF1Score >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -558,7 +558,7 @@ class MulticlassF1Score(MulticlassFBetaScore): tensor([0.6667, 0.6667, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassF1Score + >>> from torchmetrics.classification import MulticlassF1Score >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise') @@ -649,7 +649,7 @@ class MultilabelF1Score(MultilabelFBetaScore): - If ``average=None/'none'``, the shape will be ``(N, C)``` Example (preds is int tensor): - >>> from torchmetrics import MultilabelF1Score + >>> from torchmetrics.classification import MultilabelF1Score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelF1Score(num_labels=3) @@ -660,7 +660,7 @@ class MultilabelF1Score(MultilabelFBetaScore): tensor([1.0000, 0.0000, 0.6667]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelF1Score + >>> from torchmetrics.classification import MultilabelF1Score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelF1Score(num_labels=3) @@ -671,7 +671,7 @@ class MultilabelF1Score(MultilabelFBetaScore): tensor([1.0000, 0.0000, 0.6667]) Example (multidim tensors): - >>> from torchmetrics import MultilabelF1Score + >>> from torchmetrics.classification import MultilabelF1Score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -817,6 +817,37 @@ class FBetaScore(StatScores): """ full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + beta: float = 1.0, + threshold: float = 0.5, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryFBetaScore(beta, threshold, **kwargs) + if task == "multiclass": + return MulticlassFBetaScore(beta, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelFBetaScore(beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, @@ -948,6 +979,36 @@ class F1Score(FBetaScore): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryF1Score(threshold, **kwargs) + if task == "multiclass": + return MulticlassF1Score(num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelF1Score(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index 6f726dd1927..d63f0a1bf72 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any +from typing import Any, Optional import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.hamming import ( @@ -64,7 +65,7 @@ class BinaryHammingDistance(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. 
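FBetaScore.__new__ and F1Score.__new__ above follow the same pattern; on the binary branch only threshold plus the shared keyword arguments are forwarded. A hedged sketch, with the expected value taken from the BinaryF1Score docstring example:

import torch
from torchmetrics import F1Score
from torchmetrics.classification import BinaryF1Score

f1 = F1Score(task="binary")
assert isinstance(f1, BinaryF1Score)
target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
f1(preds, target)  # tensor(0.6667) in the BinaryF1Score example above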
Example (preds is int tensor): - >>> from torchmetrics import BinaryHammingDistance + >>> from torchmetrics.classification import BinaryHammingDistance >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryHammingDistance() @@ -72,7 +73,7 @@ class BinaryHammingDistance(BinaryStatScores): tensor(0.3333) Example (preds is float tensor): - >>> from torchmetrics import BinaryHammingDistance + >>> from torchmetrics.classification import BinaryHammingDistance >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryHammingDistance() @@ -80,7 +81,7 @@ class BinaryHammingDistance(BinaryStatScores): tensor(0.3333) Example (multidim tensors): - >>> from torchmetrics import BinaryHammingDistance + >>> from torchmetrics.classification import BinaryHammingDistance >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -160,7 +161,7 @@ class MulticlassHammingDistance(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassHammingDistance + >>> from torchmetrics.classification import MulticlassHammingDistance >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassHammingDistance(num_classes=3) @@ -171,7 +172,7 @@ class MulticlassHammingDistance(MulticlassStatScores): tensor([0.5000, 0.0000, 0.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassHammingDistance + >>> from torchmetrics.classification import MulticlassHammingDistance >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -187,7 +188,7 @@ class MulticlassHammingDistance(MulticlassStatScores): tensor([0.5000, 0.0000, 0.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassHammingDistance + >>> from torchmetrics.classification import MulticlassHammingDistance >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise') @@ -265,7 +266,7 @@ class MultilabelHammingDistance(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelHammingDistance + >>> from torchmetrics.classification import MultilabelHammingDistance >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelHammingDistance(num_labels=3) @@ -276,7 +277,7 @@ class MultilabelHammingDistance(MultilabelStatScores): tensor([0.0000, 0.5000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelHammingDistance + >>> from torchmetrics.classification import MultilabelHammingDistance >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelHammingDistance(num_labels=3) @@ -287,7 +288,7 @@ class MultilabelHammingDistance(MultilabelStatScores): tensor([0.0000, 0.5000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics import MultilabelHammingDistance + >>> from torchmetrics.classification import MultilabelHammingDistance >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = 
torch.tensor( ... [ @@ -361,6 +362,34 @@ class HammingDistance(Metric): correct: Tensor total: Tensor + def __new__( + cls, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[str] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + top_k: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryHammingDistance(threshold, **kwargs) + if task == "multiclass": + return MulticlassHammingDistance(num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelHammingDistance(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 2faf91860ed..ce99e28e83a 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -63,7 +63,7 @@ class BinaryHingeLoss(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> from torchmetrics import BinaryHingeLoss + >>> from torchmetrics.classification import BinaryHingeLoss >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> metric = BinaryHingeLoss() @@ -143,7 +143,7 @@ class MulticlassHingeLoss(Metric): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example: - >>> from torchmetrics import MulticlassHingeLoss + >>> from torchmetrics.classification import MulticlassHingeLoss >>> preds = torch.tensor([[0.25, 0.20, 0.55], ... [0.55, 0.05, 0.40], ... [0.10, 0.30, 0.60], @@ -281,6 +281,27 @@ class HingeLoss(Metric): measure: Tensor total: Tensor + def __new__( + cls, + squared: bool = False, + multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryHingeLoss(squared, **kwargs) + if task == "multiclass": + return MulticlassHingeLoss(num_classes, squared, multiclass_mode, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, squared: bool = False, diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 0c58d961de3..e17e07e4af5 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -59,7 +59,7 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
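HammingDistance.__new__ and HingeLoss.__new__ above share the dispatch-or-raise structure: a recognized task returns the specialized metric, while any other non-None value falls through to the ValueError. A minimal sketch of both paths:

from torchmetrics import HammingDistance, HingeLoss
from torchmetrics.classification import BinaryHammingDistance, BinaryHingeLoss

assert isinstance(HammingDistance(task="binary"), BinaryHammingDistance)
assert isinstance(HingeLoss(task="binary"), BinaryHingeLoss)

# HingeLoss has no multilabel branch, so this should hit the shared ValueError:
try:
    HingeLoss(task="multilabel")
except ValueError as err:
    print(err)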
Example (preds is int tensor): - >>> from torchmetrics import BinaryJaccardIndex + >>> from torchmetrics.classification import BinaryJaccardIndex >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> metric = BinaryJaccardIndex() @@ -67,7 +67,7 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics import BinaryJaccardIndex + >>> from torchmetrics.classification import BinaryJaccardIndex >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> metric = BinaryJaccardIndex() @@ -127,7 +127,7 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics import MulticlassJaccardIndex + >>> from torchmetrics.classification import MulticlassJaccardIndex >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassJaccardIndex(num_classes=3) @@ -135,7 +135,7 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix): tensor(0.6667) Example (pred is float tensor): - >>> from torchmetrics import MulticlassJaccardIndex + >>> from torchmetrics.classification import MulticlassJaccardIndex >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -207,7 +207,7 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import MultilabelJaccardIndex + >>> from torchmetrics.classification import MultilabelJaccardIndex >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelJaccardIndex(num_labels=3) @@ -215,7 +215,7 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix): tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics import MultilabelJaccardIndex + >>> from torchmetrics.classification import MultilabelJaccardIndex >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelJaccardIndex(num_labels=3) @@ -317,6 +317,32 @@ class JaccardIndex(ConfusionMatrix): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: int, + average: Optional[str] = "macro", + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryJaccardIndex(threshold, **kwargs) + if task == "multiclass": + return MulticlassJaccardIndex(num_classes, average, **kwargs) + if task == "multilabel": + return MultilabelJaccardIndex(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index d9b06ea38f2..b31d24b3e2e 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ 
b/src/torchmetrics/classification/matthews_corrcoef.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import ( @@ -54,7 +55,7 @@ class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import BinaryMatthewsCorrCoef + >>> from torchmetrics.classification import BinaryMatthewsCorrCoef >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> metric = BinaryMatthewsCorrCoef() @@ -62,7 +63,7 @@ class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): tensor(0.5774) Example (preds is float tensor): - >>> from torchmetrics import BinaryMatthewsCorrCoef + >>> from torchmetrics.classification import BinaryMatthewsCorrCoef >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> metric = BinaryMatthewsCorrCoef() @@ -117,7 +118,7 @@ class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics import MulticlassMatthewsCorrCoef + >>> from torchmetrics.classification import MulticlassMatthewsCorrCoef >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassMatthewsCorrCoef(num_classes=3) @@ -125,7 +126,7 @@ class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): tensor(0.7000) Example (pred is float tensor): - >>> from torchmetrics import MulticlassMatthewsCorrCoef + >>> from torchmetrics.classification import MulticlassMatthewsCorrCoef >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -186,7 +187,7 @@ class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
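Note that JaccardIndex.__new__ above keeps num_classes as a required positional argument, so it must still be supplied even though task now controls the dispatch. A sketch assuming the MulticlassJaccardIndex default average matches the "macro" passed along by __new__:

import torch
from torchmetrics import JaccardIndex
from torchmetrics.classification import MulticlassJaccardIndex

jaccard = JaccardIndex(num_classes=3, task="multiclass")
assert isinstance(jaccard, MulticlassJaccardIndex)
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([2, 1, 0, 1])
jaccard(preds, target)  # tensor(0.6667) in the MulticlassJaccardIndex example above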
Example (preds is int tensor): - >>> from torchmetrics import MultilabelMatthewsCorrCoef + >>> from torchmetrics.classification import MultilabelMatthewsCorrCoef >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) @@ -194,7 +195,7 @@ class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): tensor(0.3333) Example (preds is float tensor): - >>> from torchmetrics import MultilabelMatthewsCorrCoef + >>> from torchmetrics.classification import MultilabelMatthewsCorrCoef >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) @@ -270,6 +271,29 @@ class MatthewsCorrCoef(Metric): full_state_update: bool = False confmat: Tensor + def __new__( + cls, + num_classes: int, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryMatthewsCorrCoef(threshold, **kwargs) + if task == "multiclass": + return MulticlassMatthewsCorrCoef(num_classes, **kwargs) + if task == "multilabel": + return MultilabelMatthewsCorrCoef(num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: int, diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py index 7dc12125c0a..49f03d84c92 100644 --- a/src/torchmetrics/classification/precision_recall.py +++ b/src/torchmetrics/classification/precision_recall.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.stat_scores import ( BinaryStatScores, @@ -67,7 +68,7 @@ class BinaryPrecision(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics import BinaryPrecision + >>> from torchmetrics.classification import BinaryPrecision >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryPrecision() @@ -75,7 +76,7 @@ class BinaryPrecision(BinaryStatScores): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinaryPrecision + >>> from torchmetrics.classification import BinaryPrecision >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryPrecision() @@ -83,7 +84,7 @@ class BinaryPrecision(BinaryStatScores): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinaryPrecision + >>> from torchmetrics.classification import BinaryPrecision >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -162,7 +163,7 @@ class MulticlassPrecision(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassPrecision + >>> from torchmetrics.classification import MulticlassPrecision >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassPrecision(num_classes=3) @@ -173,7 +174,7 @@ class MulticlassPrecision(MulticlassStatScores): tensor([1.0000, 0.5000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassPrecision + >>> from torchmetrics.classification import MulticlassPrecision >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -189,7 +190,7 @@ class MulticlassPrecision(MulticlassStatScores): tensor([1.0000, 0.5000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassPrecision + >>> from torchmetrics.classification import MulticlassPrecision >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassPrecision(num_classes=3, multidim_average='samplewise') @@ -266,7 +267,7 @@ class MultilabelPrecision(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelPrecision + >>> from torchmetrics.classification import MultilabelPrecision >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelPrecision(num_labels=3) @@ -277,7 +278,7 @@ class MultilabelPrecision(MultilabelStatScores): tensor([1.0000, 0.0000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelPrecision + >>> from torchmetrics.classification import MultilabelPrecision >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelPrecision(num_labels=3) @@ -288,7 +289,7 @@ class MultilabelPrecision(MultilabelStatScores): tensor([1.0000, 0.0000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics import MultilabelPrecision + >>> from torchmetrics.classification import MultilabelPrecision >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -353,7 +354,7 @@ class BinaryRecall(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics import BinaryRecall + >>> from torchmetrics.classification import BinaryRecall >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryRecall() @@ -361,7 +362,7 @@ class BinaryRecall(BinaryStatScores): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinaryRecall + >>> from torchmetrics.classification import BinaryRecall >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryRecall() @@ -369,7 +370,7 @@ class BinaryRecall(BinaryStatScores): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinaryRecall + >>> from torchmetrics.classification import BinaryRecall >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -448,7 +449,7 @@ class MulticlassRecall(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassRecall + >>> from torchmetrics.classification import MulticlassRecall >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassRecall(num_classes=3) @@ -459,7 +460,7 @@ class MulticlassRecall(MulticlassStatScores): tensor([0.5000, 1.0000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassRecall + >>> from torchmetrics.classification import MulticlassRecall >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -475,7 +476,7 @@ class MulticlassRecall(MulticlassStatScores): tensor([0.5000, 1.0000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassRecall + >>> from torchmetrics.classification import MulticlassRecall >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassRecall(num_classes=3, multidim_average='samplewise') @@ -552,7 +553,7 @@ class MultilabelRecall(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelRecall + >>> from torchmetrics.classification import MultilabelRecall >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelRecall(num_labels=3) @@ -563,7 +564,7 @@ class MultilabelRecall(MultilabelStatScores): tensor([1., 0., 1.]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelRecall + >>> from torchmetrics.classification import MultilabelRecall >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelRecall(num_labels=3) @@ -574,7 +575,7 @@ class MultilabelRecall(MultilabelStatScores): tensor([1., 0., 1.]) Example (multidim tensors): - >>> from torchmetrics import MultilabelRecall + >>> from torchmetrics.classification import MultilabelRecall >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -698,6 +699,36 @@ class Precision(StatScores): higher_is_better = True full_state_update: bool = False + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryPrecision(threshold, **kwargs) + if task == "multiclass": + return MulticlassPrecision(num_classes, top_k, average, **kwargs) + if task == "multilabel": + return MultilabelPrecision(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, @@ -837,6 +868,36 @@ class Recall(StatScores): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + threshold: float = 0.5, + num_classes: Optional[int] = None, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryRecall(threshold, **kwargs) + if task == "multiclass": + return MulticlassRecall(num_classes, top_k, average, **kwargs) + if task == "multilabel": + return MultilabelRecall(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index eeaf8bbc27b..a8e9d8c1823 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _adjust_threshold_arg, @@ -86,7 +87,7 @@ class BinaryPrecisionRecallCurve(Metric): - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values Example: - >>> from torchmetrics import BinaryPrecisionRecallCurve + >>> from torchmetrics.classification import BinaryPrecisionRecallCurve >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> metric = BinaryPrecisionRecallCurve(thresholds=None) @@ -197,7 +198,7 @@ class MulticlassPrecisionRecallCurve(Metric): then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. 
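Precision.__new__ and Recall.__new__ above forward only threshold plus the shared keyword arguments on the binary branch, so the specialized behaviour should be reachable through the old class names. A sketch, with expected values taken from the BinaryPrecision and BinaryRecall docstring examples:

import torch
from torchmetrics import Precision, Recall
from torchmetrics.classification import BinaryPrecision, BinaryRecall

precision = Precision(task="binary")
recall = Recall(task="binary")
assert isinstance(precision, BinaryPrecision) and isinstance(recall, BinaryRecall)
target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0, 0, 1, 1, 0, 1])
precision(preds, target), recall(preds, target)  # both tensor(0.6667) per the examples above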
Example: - >>> from torchmetrics import MulticlassPrecisionRecallCurve + >>> from torchmetrics.classification import MulticlassPrecisionRecallCurve >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -330,7 +331,7 @@ class MultilabelPrecisionRecallCurve(Metric): then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. Example: - >>> from torchmetrics import MultilabelPrecisionRecallCurve + >>> from torchmetrics.classification import MultilabelPrecisionRecallCurve >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -470,6 +471,30 @@ class PrecisionRecallCurve(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryPrecisionRecallCurve(**kwargs) + if task == "multiclass": + return MulticlassPrecisionRecallCurve(num_classes, **kwargs) + if task == "multilabel": + return MultilabelPrecisionRecallCurve(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index 3f72ab0a5d5..6c00b9070bd 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -56,7 +56,7 @@ class MultilabelCoverageError(Metric): Set to ``False`` for faster computations. Example: - >>> from torchmetrics import MultilabelCoverageError + >>> from torchmetrics.classification import MultilabelCoverageError >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) @@ -126,7 +126,7 @@ class MultilabelRankingAveragePrecision(Metric): Set to ``False`` for faster computations. Example: - >>> from torchmetrics import MultilabelRankingAveragePrecision + >>> from torchmetrics.classification import MultilabelRankingAveragePrecision >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) @@ -198,7 +198,7 @@ class MultilabelRankingLoss(Metric): Set to ``False`` for faster computations. 
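PrecisionRecallCurve.__new__ above forwards thresholds, ignore_index and validate_args, and only needs num_classes or num_labels for the non-binary branches. A binary sketch, reusing the tensors from the BinaryPrecisionRecallCurve docstring example:

import torch
from torchmetrics import PrecisionRecallCurve
from torchmetrics.classification import BinaryPrecisionRecallCurve

pr_curve = PrecisionRecallCurve(task="binary", thresholds=None)
assert isinstance(pr_curve, BinaryPrecisionRecallCurve)
preds = torch.tensor([0, 0.5, 0.7, 0.8])
target = torch.tensor([0, 1, 1, 0])
precision, recall, thresholds = pr_curve(preds, target)  # three 1d tensors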
Example: - >>> from torchmetrics import MultilabelRankingLoss + >>> from torchmetrics.classification import MultilabelRankingLoss >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) diff --git a/src/torchmetrics/classification/recall_at_fixed_precision.py b/src/torchmetrics/classification/recall_at_fixed_precision.py index 81a4a534a8e..b742f6146c1 100644 --- a/src/torchmetrics/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/classification/recall_at_fixed_precision.py @@ -78,7 +78,7 @@ class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): - threshold: an scalar tensor with the corresponding threshold level Example: - >>> from torchmetrics import BinaryRecallAtFixedPrecision + >>> from torchmetrics.classification import BinaryRecallAtFixedPrecision >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> metric = BinaryRecallAtFixedPrecision(min_precision=0.5, thresholds=None) @@ -161,7 +161,7 @@ class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve): - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class Example: - >>> from torchmetrics import MulticlassRecallAtFixedPrecision + >>> from torchmetrics.classification import MulticlassRecallAtFixedPrecision >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -252,7 +252,7 @@ class MultilabelRecallAtFixedPrecision(MultilabelPrecisionRecallCurve): - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class Example: - >>> from torchmetrics import MultilabelRecallAtFixedPrecision + >>> from torchmetrics.classification import MultilabelRecallAtFixedPrecision >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index b8f5f96d336..9f30e1691e8 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.precision_recall_curve import ( BinaryPrecisionRecallCurve, @@ -82,7 +83,7 @@ class BinaryROC(BinaryPrecisionRecallCurve): - thresholds: an 1d tensor of size (n_thresholds, ) with decreasing threshold values Example: - >>> from torchmetrics import BinaryROC + >>> from torchmetrics.classification import BinaryROC >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> metric = BinaryROC(thresholds=None) @@ -164,7 +165,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. Example: - >>> from torchmetrics import MulticlassROC + >>> from torchmetrics.classification import MulticlassROC >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -262,7 +263,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. Example: - >>> from torchmetrics import MultilabelROC + >>> from torchmetrics.classification import MultilabelROC >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... 
[0.05, 0.55, 0.75], @@ -394,6 +395,30 @@ class ROC(Metric): preds: List[Tensor] target: List[Tensor] + def __new__( + cls, + num_classes: Optional[int] = None, + pos_label: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update(dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args)) + if task == "binary": + return BinaryROC(**kwargs) + if task == "multiclass": + return MulticlassROC(num_classes, **kwargs) + if task == "multilabel": + return MultilabelROC(num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index fed032471a9..e087a0e0077 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.classification.stat_scores import ( BinaryStatScores, @@ -63,7 +64,7 @@ class BinarySpecificity(BinaryStatScores): is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics import BinarySpecificity + >>> from torchmetrics.classification import BinarySpecificity >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinarySpecificity() @@ -71,7 +72,7 @@ class BinarySpecificity(BinaryStatScores): tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics import BinarySpecificity + >>> from torchmetrics.classification import BinarySpecificity >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinarySpecificity() @@ -79,7 +80,7 @@ class BinarySpecificity(BinaryStatScores): tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics import BinarySpecificity + >>> from torchmetrics.classification import BinarySpecificity >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -153,7 +154,7 @@ class MulticlassSpecificity(MulticlassStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MulticlassSpecificity + >>> from torchmetrics.classification import MulticlassSpecificity >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassSpecificity(num_classes=3) @@ -164,7 +165,7 @@ class MulticlassSpecificity(MulticlassStatScores): tensor([1.0000, 0.6667, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassSpecificity + >>> from torchmetrics.classification import MulticlassSpecificity >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... 
[0.16, 0.26, 0.58], @@ -180,7 +181,7 @@ class MulticlassSpecificity(MulticlassStatScores): tensor([1.0000, 0.6667, 1.0000]) Example (multidim tensors): - >>> from torchmetrics import MulticlassSpecificity + >>> from torchmetrics.classification import MulticlassSpecificity >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise') @@ -247,7 +248,7 @@ class MultilabelSpecificity(MultilabelStatScores): - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics import MultilabelSpecificity + >>> from torchmetrics.classification import MultilabelSpecificity >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelSpecificity(num_labels=3) @@ -258,7 +259,7 @@ class MultilabelSpecificity(MultilabelStatScores): tensor([1., 1., 0.]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelSpecificity + >>> from torchmetrics.classification import MultilabelSpecificity >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelSpecificity(num_labels=3) @@ -269,7 +270,7 @@ class MultilabelSpecificity(MultilabelStatScores): tensor([1., 1., 0.]) Example (multidim tensors): - >>> from torchmetrics import MultilabelSpecificity + >>> from torchmetrics.classification import MultilabelSpecificity >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -388,6 +389,36 @@ class Specificity(StatScores): higher_is_better: bool = True full_state_update: bool = False + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinarySpecificity(threshold, **kwargs) + if task == "multiclass": + return MulticlassSpecificity(num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelSpecificity(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, num_classes: Optional[int] = None, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 01347cf9617..39406255441 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -39,6 +39,7 @@ from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn class _AbstractStatScores(Metric): @@ -110,7 +111,7 @@ class BinaryStatScores(_AbstractStatScores): 
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import BinaryStatScores + >>> from torchmetrics.classification import BinaryStatScores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> metric = BinaryStatScores() @@ -118,7 +119,7 @@ class BinaryStatScores(_AbstractStatScores): tensor([2, 1, 2, 1, 3]) Example (preds is float tensor): - >>> from torchmetrics import BinaryStatScores + >>> from torchmetrics.classification import BinaryStatScores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> metric = BinaryStatScores() @@ -126,7 +127,7 @@ class BinaryStatScores(_AbstractStatScores): tensor([2, 1, 2, 1, 3]) Example (multidim tensors): - >>> from torchmetrics import BinaryStatScores + >>> from torchmetrics.classification import BinaryStatScores >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -151,7 +152,7 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: - super().__init__(**kwargs) + super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) self.threshold = threshold @@ -230,7 +231,7 @@ class MulticlassStatScores(_AbstractStatScores): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics import MulticlassStatScores + >>> from torchmetrics.classification import MulticlassStatScores >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassStatScores(num_classes=3, average='micro') @@ -243,7 +244,7 @@ class MulticlassStatScores(_AbstractStatScores): [1, 0, 3, 0, 1]]) Example (preds is float tensor): - >>> from torchmetrics import MulticlassStatScores + >>> from torchmetrics.classification import MulticlassStatScores >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -261,7 +262,7 @@ class MulticlassStatScores(_AbstractStatScores): [1, 0, 3, 0, 1]]) Example (multidim tensors): - >>> from torchmetrics import MulticlassStatScores + >>> from torchmetrics.classification import MulticlassStatScores >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> metric = MulticlassStatScores(num_classes=3, multidim_average="samplewise", average='micro') @@ -291,7 +292,7 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: - super().__init__(**kwargs) + super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) self.num_classes = num_classes @@ -379,7 +380,7 @@ class MultilabelStatScores(_AbstractStatScores): kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
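The super(_AbstractStatScores, self).__init__(**kwargs) form introduced above starts the method lookup after _AbstractStatScores in the MRO, so when a BinaryStatScores, MulticlassStatScores or MultilabelStatScores instance is initialized, _AbstractStatScores is skipped and Metric.__init__ runs directly. A toy illustration of that Python mechanism (not torchmetrics code):

class A:
    def __init__(self):
        print("A.__init__")

class B(A):
    def __init__(self):
        print("B.__init__")
        super().__init__()

class C(B):
    def __init__(self):
        # like super(_AbstractStatScores, self): skip B in the MRO and resolve to A.__init__
        super(B, self).__init__()

C()  # prints only "A.__init__"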
Example (preds is int tensor): - >>> from torchmetrics import MultilabelStatScores + >>> from torchmetrics.classification import MultilabelStatScores >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> metric = MultilabelStatScores(num_labels=3, average='micro') @@ -392,7 +393,7 @@ class MultilabelStatScores(_AbstractStatScores): [1, 1, 0, 0, 1]]) Example (preds is float tensor): - >>> from torchmetrics import MultilabelStatScores + >>> from torchmetrics.classification import MultilabelStatScores >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> metric = MultilabelStatScores(num_labels=3, average='micro') @@ -405,7 +406,7 @@ class MultilabelStatScores(_AbstractStatScores): [1, 1, 0, 0, 1]]) Example (multidim tensors): - >>> from torchmetrics import MultilabelStatScores + >>> from torchmetrics.classification import MultilabelStatScores >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -441,7 +442,7 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: - super().__init__(**kwargs) + super(_AbstractStatScores, self).__init__(**kwargs) if validate_args: _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) self.num_labels = num_labels @@ -489,9 +490,6 @@ def compute(self) -> Tensor: return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -# -------------------------- Old stuff -------------------------- - - class StatScores(Metric): r"""Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. 
@@ -590,6 +588,36 @@ class StatScores(Metric): # tn: Union[Tensor, List[Tensor]] # fn: Union[Tensor, List[Tensor]] + def __new__( + cls, + num_classes: Optional[int] = None, + threshold: float = 0.5, + average: Optional[str] = "micro", + mdmc_average: Optional[str] = None, + ignore_index: Optional[int] = None, + top_k: Optional[int] = None, + multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + return BinaryStatScores(threshold, **kwargs) + if task == "multiclass": + return MulticlassStatScores(num_classes, average, top_k, **kwargs) + if task == "multilabel": + return MultilabelStatScores(num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + return super().__new__(cls) + def __init__( self, threshold: float = 0.5, @@ -599,9 +627,38 @@ def __init__( ignore_index: Optional[int] = None, mdmc_reduce: Optional[str] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, **kwargs: Any, ) -> None: - super().__init__(**kwargs) + self.task = task + if self.task is not None: + kwargs.update( + dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + ) + if task == "binary": + BinaryStatScores.__init__(self, threshold, **kwargs) + if task == "multiclass": + MulticlassStatScores.__init__(self, num_classes, top_k, average, **kwargs) + if task == "multilabel": + MultilabelStatScores.__init__(self, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + + Metric.__init__(self, **kwargs) self.reduce = reduce self.mdmc_reduce = mdmc_reduce @@ -647,6 +704,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore preds: Predictions from model (probabilities, logits or labels) target: Ground truth values """ + if self.task is not None: + if self.task == "binary": + BinaryStatScores.update(self, preds, target) + elif self.task == "multiclass": + MulticlassStatScores.update(self, preds, target) + elif self.task == "multilabel": + MultilabelStatScores.update(self, preds, target) + return tp, fp, tn, fn = _stat_scores_update( preds, @@ -711,5 +776,12 @@ def compute(self) -> Tensor: - If ``reduce='macro'``, the shape will be ``(N, C, 5)`` - If ``reduce='samples'``, the shape will be ``(N, X, 5)`` """ + if self.task is not None: + if self.task == "binary": + return BinaryStatScores.compute(self) + elif self.task == "multiclass": + return MulticlassStatScores.compute(self) + elif self.task == "multilabel": + return MultilabelStatScores.compute(self) tp, fp, tn, fn = self._get_final_stats() return _stat_scores_compute(tp, fp, tn, fn) diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index c3adf63a33e..8e71ed4de63 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -14,106 +14,29 @@ from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio -from torchmetrics.functional.classification.accuracy import ( - accuracy, - binary_accuracy, - multiclass_accuracy, - multilabel_accuracy, -) +from torchmetrics.functional.classification.accuracy import accuracy from torchmetrics.functional.classification.auc import auc -from torchmetrics.functional.classification.auroc import auroc, binary_auroc, multiclass_auroc, multilabel_auroc -from torchmetrics.functional.classification.average_precision import ( - average_precision, - binary_average_precision, - multiclass_average_precision, - multilabel_average_precision, -) -from torchmetrics.functional.classification.calibration_error import ( - binary_calibration_error, - calibration_error, - multiclass_calibration_error, -) -from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa -from torchmetrics.functional.classification.confusion_matrix import ( - binary_confusion_matrix, - confusion_matrix, - multiclass_confusion_matrix, - multilabel_confusion_matrix, -) +from torchmetrics.functional.classification.auroc import auroc +from torchmetrics.functional.classification.average_precision import average_precision +from torchmetrics.functional.classification.calibration_error import calibration_error +from torchmetrics.functional.classification.cohen_kappa import cohen_kappa +from torchmetrics.functional.classification.confusion_matrix import confusion_matrix from torchmetrics.functional.classification.dice import dice, dice_score -from torchmetrics.functional.classification.exact_match import multilabel_exact_match -from 
torchmetrics.functional.classification.f_beta import ( - binary_f1_score, - binary_fbeta_score, - f1_score, - fbeta_score, - multiclass_f1_score, - multiclass_fbeta_score, - multilabel_f1_score, - multilabel_fbeta_score, -) -from torchmetrics.functional.classification.hamming import ( - binary_hamming_distance, - hamming_distance, - multiclass_hamming_distance, - multilabel_hamming_distance, -) -from torchmetrics.functional.classification.hinge import binary_hinge_loss, hinge_loss, multiclass_hinge_loss -from torchmetrics.functional.classification.jaccard import ( - binary_jaccard_index, - jaccard_index, - multiclass_jaccard_index, - multilabel_jaccard_index, -) -from torchmetrics.functional.classification.matthews_corrcoef import ( - binary_matthews_corrcoef, - matthews_corrcoef, - multiclass_matthews_corrcoef, - multilabel_matthews_corrcoef, -) -from torchmetrics.functional.classification.precision_recall import ( - binary_precision, - binary_recall, - multiclass_precision, - multiclass_recall, - multilabel_precision, - multilabel_recall, - precision, - precision_recall, - recall, -) -from torchmetrics.functional.classification.precision_recall_curve import ( - binary_precision_recall_curve, - multiclass_precision_recall_curve, - multilabel_precision_recall_curve, - precision_recall_curve, -) +from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score +from torchmetrics.functional.classification.hamming import hamming_distance +from torchmetrics.functional.classification.hinge import hinge_loss +from torchmetrics.functional.classification.jaccard import jaccard_index +from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef +from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall +from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve from torchmetrics.functional.classification.ranking import ( coverage_error, label_ranking_average_precision, label_ranking_loss, - multilabel_coverage_error, - multilabel_ranking_average_precision, - multilabel_ranking_loss, -) -from torchmetrics.functional.classification.recall_at_fixed_precision import ( - binary_recall_at_fixed_precision, - multiclass_recall_at_fixed_precision, - multilabel_recall_at_fixed_precision, -) -from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc, roc -from torchmetrics.functional.classification.specificity import ( - binary_specificity, - multiclass_specificity, - multilabel_specificity, - specificity, -) -from torchmetrics.functional.classification.stat_scores import ( - binary_stat_scores, - multiclass_stat_scores, - multilabel_stat_scores, - stat_scores, ) +from torchmetrics.functional.classification.roc import roc +from torchmetrics.functional.classification.specificity import specificity +from torchmetrics.functional.classification.stat_scores import stat_scores from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.functional.image.gradients import image_gradients @@ -248,63 +171,4 @@ "word_error_rate", "word_information_lost", "word_information_preserved", -] + [ - "binary_confusion_matrix", - "multiclass_confusion_matrix", - "multilabel_confusion_matrix", - "binary_stat_scores", - "multiclass_stat_scores", - "multilabel_stat_scores", - "binary_f1_score", - "binary_fbeta_score", - "multiclass_f1_score", - 
"multiclass_fbeta_score", - "multilabel_f1_score", - "multilabel_fbeta_score", - "binary_cohen_kappa", - "multiclass_cohen_kappa", - "binary_jaccard_index", - "multiclass_jaccard_index", - "multilabel_jaccard_index", - "binary_matthews_corrcoef", - "multiclass_matthews_corrcoef", - "multilabel_matthews_corrcoef", - "multilabel_coverage_error", - "multilabel_ranking_average_precision", - "multilabel_ranking_loss", - "binary_accuracy", - "multilabel_accuracy", - "multiclass_accuracy", - "binary_specificity", - "multiclass_specificity", - "multilabel_specificity", - "binary_hamming_distance", - "multiclass_hamming_distance", - "multilabel_hamming_distance", - "binary_precision", - "multiclass_precision", - "multilabel_precision", - "binary_recall", - "multiclass_recall", - "multilabel_recall", - "multilabel_exact_match", - "binary_auroc", - "multiclass_auroc", - "multilabel_auroc", - "binary_average_precision", - "multiclass_average_precision", - "multilabel_average_precision", - "binary_precision_recall_curve", - "multiclass_precision_recall_curve", - "multilabel_precision_recall_curve", - "binary_recall_at_fixed_precision", - "multiclass_recall_at_fixed_precision", - "multilabel_recall_at_fixed_precision", - "binary_roc", - "multiclass_roc", - "multilabel_roc", - "binary_calibration_error", - "multiclass_calibration_error", - "binary_hinge_loss", - "multiclass_hinge_loss", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 0c791e55981..82932c0d6e3 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -35,7 +35,11 @@ calibration_error, multiclass_calibration_error, ) -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 +from torchmetrics.functional.classification.cohen_kappa import ( # noqa: F401 + binary_cohen_kappa, + cohen_kappa, + multiclass_cohen_kappa, +) from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 binary_confusion_matrix, confusion_matrix, @@ -65,8 +69,18 @@ hinge_loss, multiclass_hinge_loss, ) -from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 -from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 +from torchmetrics.functional.classification.jaccard import ( # noqa: F401 + binary_jaccard_index, + jaccard_index, + multiclass_jaccard_index, + multilabel_jaccard_index, +) +from torchmetrics.functional.classification.matthews_corrcoef import ( # noqa: F401 + binary_matthews_corrcoef, + matthews_corrcoef, + multiclass_matthews_corrcoef, + multilabel_matthews_corrcoef, +) from torchmetrics.functional.classification.precision_recall import ( # noqa: F401 binary_precision, binary_recall, diff --git a/src/torchmetrics/functional/classification/accuracy.py b/src/torchmetrics/functional/classification/accuracy.py index c2b313ce1c6..88a5174bdd8 100644 --- a/src/torchmetrics/functional/classification/accuracy.py +++ b/src/torchmetrics/functional/classification/accuracy.py @@ -36,6 +36,7 @@ from torchmetrics.utilities.checks import _check_classification_inputs, _input_format_classification, _input_squeeze from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def _accuracy_reduce( @@ -134,21 +135,21 @@ def binary_accuracy( is set to 
``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_accuracy + >>> from torchmetrics.functional.classification import binary_accuracy >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_accuracy(preds, target) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_accuracy + >>> from torchmetrics.functional.classification import binary_accuracy >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_accuracy(preds, target) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_accuracy + >>> from torchmetrics.functional.classification import binary_accuracy >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -236,7 +237,7 @@ def multiclass_accuracy( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_accuracy + >>> from torchmetrics.functional.classification import multiclass_accuracy >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_accuracy(preds, target, num_classes=3) @@ -245,7 +246,7 @@ def multiclass_accuracy( tensor([0.5000, 1.0000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_accuracy + >>> from torchmetrics.functional.classification import multiclass_accuracy >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -259,7 +260,7 @@ def multiclass_accuracy( tensor([0.5000, 1.0000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_accuracy + >>> from torchmetrics.functional.classification import multiclass_accuracy >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_accuracy(preds, target, num_classes=3, multidim_average='samplewise') @@ -343,7 +344,7 @@ def multilabel_accuracy( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_accuracy + >>> from torchmetrics.functional.classification import multilabel_accuracy >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_accuracy(preds, target, num_labels=3) @@ -352,7 +353,7 @@ def multilabel_accuracy( tensor([1.0000, 0.5000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_accuracy + >>> from torchmetrics.functional.classification import multilabel_accuracy >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_accuracy(preds, target, num_labels=3) @@ -361,7 +362,7 @@ def multilabel_accuracy( tensor([1.0000, 0.5000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_accuracy + >>> from torchmetrics.functional.classification import multilabel_accuracy >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -632,8 +633,20 @@ def accuracy( num_classes: Optional[int] = None, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Accuracy`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Accuracy`_ .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) @@ -756,6 +769,27 @@ def accuracy( >>> accuracy(preds, target, top_k=2) tensor(0.6667) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_accuracy(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_accuracy(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_accuracy(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 34d6b28633b..85a8c8f8313 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -170,7 +170,7 @@ def binary_auroc( A single scalar with the auroc score Example: - >>> from torchmetrics.functional import binary_auroc + >>> from torchmetrics.functional.classification import binary_auroc >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> binary_auroc(preds, target, thresholds=None) @@ -273,7 +273,7 @@ def multiclass_auroc( If `average="macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics.functional import multiclass_auroc + >>> from torchmetrics.functional.classification import multiclass_auroc >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -401,7 +401,7 @@ def multilabel_auroc( If `average="micro|macro"|"weighted"` then a single scalar is returned. 
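For reference, the ``task`` dispatch added to ``accuracy`` above is expected to make the legacy entrypoint reproduce the specialized functions. A minimal sketch, reusing the inputs from the ``binary_accuracy`` doctest and assuming the default ``threshold=0.5``:

    >>> import torch
    >>> from torchmetrics.functional import accuracy
    >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
    >>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
    >>> accuracy(preds, target, task="binary")  # expected to route to binary_accuracy
    tensor(0.6667)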
Example: - >>> from torchmetrics.functional import multilabel_auroc + >>> from torchmetrics.functional.classification import multilabel_auroc >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -608,8 +608,21 @@ def auroc( average: Optional[str] = "macro", max_fpr: Optional[float] = None, sample_weights: Optional[Sequence] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: - """Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Compute Area Under the Receiver Operating Characteristic Curve (`ROC AUC`_) For non-binary input, if the ``preds`` and ``target`` tensor have the same size the input will be interpretated as multilabel and if ``preds`` have one @@ -672,5 +685,26 @@ def auroc( >>> auroc(preds, target, num_classes=3) tensor(0.7778) """ + if task is not None: + kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_auroc(preds, target, max_fpr, **kwargs) + if task == "multiclass": + return multiclass_auroc(preds, target, num_classes, average, **kwargs) + if task == "multilabel": + return multilabel_auroc(preds, target, num_labels, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + preds, target, mode = _auroc_update(preds, target) return _auroc_compute(preds, target, mode, num_classes, pos_label, average, max_fpr, sample_weights) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 85b2d97fc35..4bff1961f1c 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -136,7 +136,7 @@ def binary_average_precision( A single scalar with the average precision score Example: - >>> from torchmetrics.functional import binary_average_precision + >>> from torchmetrics.functional.classification import binary_average_precision >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> binary_average_precision(preds, target, thresholds=None) @@ -245,7 +245,7 @@ def multiclass_average_precision( If `average="macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics.functional import multiclass_average_precision + >>> from torchmetrics.functional.classification import multiclass_average_precision >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -377,7 +377,7 @@ def multilabel_average_precision( If `average="micro|macro"|"weighted"` then a single scalar is returned. Example: - >>> from torchmetrics.functional import multilabel_average_precision + >>> from torchmetrics.functional.classification import multilabel_average_precision >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -563,8 +563,21 @@ def average_precision( num_classes: Optional[int] = None, pos_label: Optional[int] = None, average: Optional[str] = "macro", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[List[Tensor], Tensor]: - """Computes the average precision score. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the average precision score. 
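As with ``auroc`` above, passing ``task`` here should simply forward to the corresponding specialized function (the binary dispatch is added further down in this function). A minimal sketch, reusing the inputs from the ``binary_average_precision`` doctest and checking the routing rather than asserting a numeric value:

    >>> import torch
    >>> from torchmetrics.functional import average_precision
    >>> from torchmetrics.functional.classification import binary_average_precision
    >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> torch.equal(average_precision(preds, target, task="binary", thresholds=None),
    ...             binary_average_precision(preds, target, thresholds=None))
    True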
Args: preds: predictions from model (logits or probabilities) @@ -607,5 +620,25 @@ def average_precision( >>> average_precision(pred, target, num_classes=5, average=None) [tensor(1.), tensor(1.), tensor(0.2500), tensor(0.2500), tensor(nan)] """ + if task is not None: + kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_average_precision(preds, target, **kwargs) + if task == "multiclass": + return multiclass_average_precision(preds, target, num_classes, average, **kwargs) + if task == "multilabel": + return multilabel_average_precision(preds, target, num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _average_precision_update(preds, target, num_classes, pos_label, average) return _average_precision_compute(preds, target, num_classes, pos_label, average) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index d56eeb9190e..e4698d4de47 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -26,6 +26,7 @@ from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.utilities.prints import rank_zero_warn def _binning_with_loop( @@ -216,7 +217,7 @@ def binary_calibration_error( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import binary_calibration_error + >>> from torchmetrics.functional.classification import binary_calibration_error >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> binary_calibration_error(preds, target, n_bins=2, norm='l1') @@ -325,7 +326,7 @@ def multiclass_calibration_error( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import multiclass_calibration_error + >>> from torchmetrics.functional.classification import multiclass_calibration_error >>> preds = torch.tensor([[0.25, 0.20, 0.55], ... [0.55, 0.05, 0.40], ... [0.10, 0.30, 0.60], @@ -388,8 +389,25 @@ def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: return confidences.float(), accuracies.float() -def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str = "l1") -> Tensor: - r"""`Computes the Top-label Calibration Error`_ +def calibration_error( + preds: Tensor, + target: Tensor, + n_bins: int = 15, + norm: str = "l1", + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + .. 
note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + `Computes the Top-label Calibration Error`_ Three different norms are implemented, each corresponding to variations on the calibration error metric. @@ -422,6 +440,23 @@ def calibration_error(preds: Tensor, target: Tensor, n_bins: int = 15, norm: str norm: Norm used to compare empirical and expected probability bins. Defaults to "l1", or Expected Calibration Error. """ + if task is not None: + kwargs = dict(norm=norm, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_calibration_error(preds, target, n_bins, **kwargs) + if task == "multiclass": + return multiclass_calibration_error(preds, target, num_classes, n_bins, **kwargs) + raise ValueError(f"Expected argument `task` to either be `'binary'`, `'multiclass'` but got {task}") + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + if norm not in ("l1", "l2", "max"): raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index c72ab4a696b..5d3d8a064c6 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -29,6 +29,7 @@ _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, ) +from torchmetrics.utilities.prints import rank_zero_warn def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: @@ -79,8 +80,8 @@ def binary_cohen_kappa( preds: Tensor, target: Tensor, threshold: float = 0.5, - ignore_index: Optional[int] = None, weights: Optional[Literal["linear", "quadratic", "none"]] = None, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for binary @@ -107,26 +108,26 @@ class labels. preds: Tensor with predictions target: Tensor with true labels threshold: Threshold for transforming probability to binary (0,1) predictions - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation weights: Weighting type to calculate the score. Choose from: - ``None`` or ``'none'``: no weighting - ``'linear'``: linear weighting - ``'quadratic'``: quadratic weighting + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. 
Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_cohen_kappa + >>> from torchmetrics.functional.classification import binary_cohen_kappa >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> binary_cohen_kappa(preds, target) tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_cohen_kappa + >>> from torchmetrics.functional.classification import binary_cohen_kappa >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> binary_cohen_kappa(preds, target) @@ -162,8 +163,8 @@ def multiclass_cohen_kappa( preds: Tensor, target: Tensor, num_classes: int, - ignore_index: Optional[int] = None, weights: Optional[Literal["linear", "quadratic", "none"]] = None, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass @@ -190,27 +191,28 @@ class labels. preds: Tensor with predictions target: Tensor with true labels num_classes: Integer specifing the number of classes - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation weights: Weighting type to calculate the score. Choose from: - ``None`` or ``'none'``: no weighting - ``'linear'``: linear weighting - ``'quadratic'``: quadratic weighting + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics.functional import multiclass_cohen_kappa + >>> from torchmetrics.functional.classification import multiclass_cohen_kappa >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_cohen_kappa(preds, target, num_classes=3) tensor(0.6364) Example (pred is float tensor): - >>> from torchmetrics.functional import multiclass_cohen_kappa + >>> from torchmetrics.functional.classification import multiclass_cohen_kappa >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -287,8 +289,20 @@ def cohen_kappa( num_classes: int, weights: Optional[str] = None, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: - r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. It is defined as @@ -320,5 +334,24 @@ class labels. 
>>> cohen_kappa(preds, target, num_classes=2) tensor(0.5000) """ + if task is not None: + kwargs = dict(weights=weights, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_cohen_kappa(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_cohen_kappa(preds, target, num_classes, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) + confmat = _cohen_kappa_update(preds, target, num_classes, threshold) return _cohen_kappa_compute(confmat, weights) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 4d841d37f8c..41cbc67ffa5 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -162,8 +162,8 @@ def binary_confusion_matrix( preds: Tensor, target: Tensor, threshold: float = 0.5, - ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r""" @@ -182,14 +182,14 @@ def binary_confusion_matrix( preds: Tensor with predictions target: Tensor with true labels threshold: Threshold for transforming probability to binary (0,1) predictions - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. 
@@ -197,7 +197,7 @@ def binary_confusion_matrix( A ``[2, 2]`` tensor Example (preds is int tensor): - >>> from torchmetrics.functional import binary_confusion_matrix + >>> from torchmetrics.functional.classification import binary_confusion_matrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> binary_confusion_matrix(preds, target) @@ -205,7 +205,7 @@ def binary_confusion_matrix( [1, 1]]) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_confusion_matrix + >>> from torchmetrics.functional.classification import binary_confusion_matrix >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> binary_confusion_matrix(preds, target) @@ -347,8 +347,8 @@ def multiclass_confusion_matrix( preds: Tensor, target: Tensor, num_classes: int, - ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r""" @@ -367,14 +367,14 @@ def multiclass_confusion_matrix( preds: Tensor with predictions target: Tensor with true labels num_classes: Integer specifing the number of classes - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. @@ -382,7 +382,7 @@ def multiclass_confusion_matrix( A ``[num_classes, num_classes]`` tensor Example (pred is integer tensor): - >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> from torchmetrics.functional.classification import multiclass_confusion_matrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_confusion_matrix(preds, target, num_classes=3) @@ -391,7 +391,7 @@ def multiclass_confusion_matrix( [0, 0, 1]]) Example (pred is float tensor): - >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> from torchmetrics.functional.classification import multiclass_confusion_matrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -534,8 +534,8 @@ def multilabel_confusion_matrix( target: Tensor, num_labels: int, threshold: float = 0.5, - ignore_index: Optional[int] = None, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r""" @@ -555,14 +555,14 @@ def multilabel_confusion_matrix( target: Tensor with true labels num_labels: Integer specifing the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. 
Choose from: - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. @@ -570,7 +570,7 @@ def multilabel_confusion_matrix( A ``[num_labels, 2, 2]`` tensor Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> from torchmetrics.functional.classification import multilabel_confusion_matrix >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_confusion_matrix(preds, target, num_labels=3) @@ -579,7 +579,7 @@ def multilabel_confusion_matrix( [[0, 1], [0, 1]]]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> from torchmetrics.functional.classification import multilabel_confusion_matrix >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_confusion_matrix(preds, target, num_labels=3) @@ -696,8 +696,20 @@ def confusion_matrix( normalize: Optional[str] = None, threshold: float = 0.5, multilabel: bool = False, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the `confusion matrix`_. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. Works with multi-dimensional preds and target, but it should be noted that @@ -758,12 +770,25 @@ def confusion_matrix( [[0, 1], [0, 1]]]) """ - rank_zero_warn( - "`torchmetrics.functional.confusion_matrix` have been deprecated in v0.10 in favor of" - "`torchmetrics.functional.binary_confusion_matrix`, `torchmetrics.functional.multiclass_confusion_matrix`" - "and `torchmetrics.functional.multilabel_confusion_matrix`. Please upgrade to the version that matches" - "your problem (API may have changed). 
This function will be removed v0.11.", - DeprecationWarning, - ) + if task is not None: + kwargs = dict(normalize=normalize, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_confusion_matrix(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_confusion_matrix(preds, target, num_classes, **kwargs) + if task == "multilabel": + return multilabel_confusion_matrix(preds, target, num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel) return _confusion_matrix_compute(confmat, normalize) diff --git a/src/torchmetrics/functional/classification/exact_match.py b/src/torchmetrics/functional/classification/exact_match.py index ad1838ad8ef..36a024f3b4b 100644 --- a/src/torchmetrics/functional/classification/exact_match.py +++ b/src/torchmetrics/functional/classification/exact_match.py @@ -100,21 +100,21 @@ def multilabel_exact_match( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_exact_match + >>> from torchmetrics.functional.classification import multilabel_exact_match >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_exact_match(preds, target, num_labels=3) tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_exact_match + >>> from torchmetrics.functional.classification import multilabel_exact_match >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_exact_match(preds, target, num_labels=3) tensor(0.5000) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_exact_match + >>> from torchmetrics.functional.classification import multilabel_exact_match >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index 81012d49df5..bea2e88a6f1 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -36,6 +36,7 @@ from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod as AvgMethod from torchmetrics.utilities.enums import MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def _fbeta_reduce( @@ -125,21 +126,21 @@ def binary_fbeta_score( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. 
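Tying back to the ``confusion_matrix`` dispatch added above: with ``task="binary"`` the legacy function should defer to ``binary_confusion_matrix`` and match its doctest. A minimal sketch, where ``num_classes`` is passed only to satisfy the old signature and is expected to be ignored by the binary path:

    >>> import torch
    >>> from torchmetrics.functional import confusion_matrix
    >>> target = torch.tensor([1, 1, 0, 0])
    >>> preds = torch.tensor([0, 1, 0, 0])
    >>> confusion_matrix(preds, target, num_classes=2, task="binary")
    tensor([[2, 0],
            [1, 1]])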
Example (preds is int tensor): - >>> from torchmetrics.functional import binary_fbeta_score + >>> from torchmetrics.functional.classification import binary_fbeta_score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_fbeta_score(preds, target, beta=2.0) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_fbeta_score + >>> from torchmetrics.functional.classification import binary_fbeta_score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_fbeta_score(preds, target, beta=2.0) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_fbeta_score + >>> from torchmetrics.functional.classification import binary_fbeta_score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -240,7 +241,7 @@ def multiclass_fbeta_score( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_fbeta_score + >>> from torchmetrics.functional.classification import multiclass_fbeta_score >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3) @@ -249,7 +250,7 @@ def multiclass_fbeta_score( tensor([0.5556, 0.8333, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_fbeta_score + >>> from torchmetrics.functional.classification import multiclass_fbeta_score >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -263,7 +264,7 @@ def multiclass_fbeta_score( tensor([0.5556, 0.8333, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_fbeta_score + >>> from torchmetrics.functional.classification import multiclass_fbeta_score >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_fbeta_score(preds, target, beta=2.0, num_classes=3, multidim_average='samplewise') @@ -361,7 +362,7 @@ def multilabel_fbeta_score( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_fbeta_score + >>> from torchmetrics.functional.classification import multilabel_fbeta_score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3) @@ -370,7 +371,7 @@ def multilabel_fbeta_score( tensor([1.0000, 0.0000, 0.8333]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_fbeta_score + >>> from torchmetrics.functional.classification import multilabel_fbeta_score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_fbeta_score(preds, target, beta=2.0, num_labels=3) @@ -379,7 +380,7 @@ def multilabel_fbeta_score( tensor([1.0000, 0.0000, 0.8333]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_fbeta_score + >>> from torchmetrics.functional.classification import multilabel_fbeta_score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -447,21 +448,21 @@ def binary_f1_score( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_f1_score + >>> from torchmetrics.functional.classification import binary_f1_score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_f1_score(preds, target) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_f1_score + >>> from torchmetrics.functional.classification import binary_f1_score >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_f1_score(preds, target) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_f1_score + >>> from torchmetrics.functional.classification import binary_f1_score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -549,7 +550,7 @@ def multiclass_f1_score( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_f1_score + >>> from torchmetrics.functional.classification import multiclass_f1_score >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_f1_score(preds, target, num_classes=3) @@ -558,7 +559,7 @@ def multiclass_f1_score( tensor([0.6667, 0.6667, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_f1_score + >>> from torchmetrics.functional.classification import multiclass_f1_score >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -572,7 +573,7 @@ def multiclass_f1_score( tensor([0.6667, 0.6667, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_f1_score + >>> from torchmetrics.functional.classification import multiclass_f1_score >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_f1_score(preds, target, num_classes=3, multidim_average='samplewise') @@ -659,7 +660,7 @@ def multilabel_f1_score( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_f1_score + >>> from torchmetrics.functional.classification import multilabel_f1_score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_f1_score(preds, target, num_labels=3) @@ -668,7 +669,7 @@ def multilabel_f1_score( tensor([1.0000, 0.0000, 0.6667]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_f1_score + >>> from torchmetrics.functional.classification import multilabel_f1_score >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_f1_score(preds, target, num_labels=3) @@ -677,7 +678,7 @@ def multilabel_f1_score( tensor([1.0000, 0.0000, 0.6667]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_f1_score + >>> from torchmetrics.functional.classification import multilabel_f1_score >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -800,8 +801,20 @@ def fbeta_score( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes f_beta metric. .. math:: @@ -896,6 +909,26 @@ def fbeta_score( tensor(0.3333) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_fbeta_score(preds, target, beta, threshold, **kwargs) + if task == "multiclass": + return multiclass_fbeta_score(preds, target, beta, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_fbeta_score(preds, target, beta, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = list(AvgMethod) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -936,8 +969,20 @@ def f1_score( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - """Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes F1 metric. F1 metrics correspond to a equally weighted average of the precision and recall scores. Works with binary, multiclass, and multilabel data. Accepts probabilities or logits from a model output or integer class values in prediction. 
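Analogously, the dispatch just added to ``fbeta_score`` means a legacy call with ``task="binary"`` should reduce to ``binary_fbeta_score``; ``f1_score`` below receives the same treatment. A minimal sketch, reusing the ``binary_fbeta_score`` doctest inputs and assuming the default ``threshold=0.5``:

    >>> import torch
    >>> from torchmetrics.functional import fbeta_score
    >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
    >>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
    >>> fbeta_score(preds, target, beta=2.0, task="binary")  # expected to route to binary_fbeta_score
    tensor(0.6667)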
@@ -1030,6 +1075,26 @@ def f1_score( >>> f1_score(preds, target, num_classes=3) tensor(0.3333) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_f1_score(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_f1_score(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_f1_score(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) return fbeta_score( preds, target, 1.0, average, mdmc_average, ignore_index, num_classes, threshold, top_k, multiclass ) diff --git a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py index 0d63409c0e5..84d459de82d 100644 --- a/src/torchmetrics/functional/classification/hamming.py +++ b/src/torchmetrics/functional/classification/hamming.py @@ -33,6 +33,7 @@ ) from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.prints import rank_zero_warn def _hamming_distance_reduce( @@ -134,21 +135,21 @@ def binary_hamming_distance( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_hamming_distance + >>> from torchmetrics.functional.classification import binary_hamming_distance >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_hamming_distance(preds, target) tensor(0.3333) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_hamming_distance + >>> from torchmetrics.functional.classification import binary_hamming_distance >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_hamming_distance(preds, target) tensor(0.3333) Example (multidim tensors): - >>> from torchmetrics.functional import binary_hamming_distance + >>> from torchmetrics.functional.classification import binary_hamming_distance >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -237,7 +238,7 @@ def multiclass_hamming_distance( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_hamming_distance + >>> from torchmetrics.functional.classification import multiclass_hamming_distance >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_hamming_distance(preds, target, num_classes=3) @@ -246,7 +247,7 @@ def multiclass_hamming_distance( tensor([0.5000, 0.0000, 0.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_hamming_distance + >>> from torchmetrics.functional.classification import multiclass_hamming_distance >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -260,7 +261,7 @@ def multiclass_hamming_distance( tensor([0.5000, 0.0000, 0.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_hamming_distance + >>> from torchmetrics.functional.classification import multiclass_hamming_distance >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise') @@ -345,7 +346,7 @@ def multilabel_hamming_distance( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_hamming_distance + >>> from torchmetrics.functional.classification import multilabel_hamming_distance >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_hamming_distance(preds, target, num_labels=3) @@ -354,7 +355,7 @@ def multilabel_hamming_distance( tensor([0.0000, 0.5000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_hamming_distance + >>> from torchmetrics.functional.classification import multilabel_hamming_distance >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_hamming_distance(preds, target, num_labels=3) @@ -363,7 +364,7 @@ def multilabel_hamming_distance( tensor([0.0000, 0.5000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_hamming_distance + >>> from torchmetrics.functional.classification import multilabel_hamming_distance >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -429,8 +430,27 @@ def _hamming_distance_compute(correct: Tensor, total: Union[int, Tensor]) -> Ten return 1 - correct.float() / total -def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> Tensor: +def hamming_distance( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + average: Optional[str] = "macro", + top_k: int = 1, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: @@ -461,6 +481,26 @@ def hamming_distance(preds: Tensor, target: Tensor, threshold: float = 0.5) -> T >>> hamming_distance(preds, target) tensor(0.2500) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_hamming_distance(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_hamming_distance(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_hamming_distance(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) correct, total = _hamming_distance_update(preds, target, threshold) return _hamming_distance_compute(correct, total) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 623900b23cc..e6b9388a3cd 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -26,6 +26,7 @@ from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr +from torchmetrics.utilities.prints import rank_zero_warn def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -105,7 +106,7 @@ def binary_hinge_loss( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import binary_hinge_loss + >>> from torchmetrics.functional.classification import binary_hinge_loss >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) >>> target = torch.tensor([0, 0, 1, 1, 1]) >>> binary_hinge_loss(preds, target) @@ -221,7 +222,7 @@ def multiclass_hinge_loss( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import multiclass_hinge_loss + >>> from torchmetrics.functional.classification import multiclass_hinge_loss >>> preds = torch.tensor([[0.25, 0.20, 0.55], ... [0.55, 0.05, 0.40], ... [0.10, 0.30, 0.60], @@ -383,8 +384,19 @@ def hinge_loss( target: Tensor, squared: bool = False, multiclass_mode: Optional[Union[str, MulticlassMode]] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. 
This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs). In the binary case it is defined as: @@ -451,5 +463,23 @@ def hinge_loss( >>> hinge_loss(preds, target, multiclass_mode="one-vs-all") tensor([2.2333, 1.5000, 1.2333]) """ + if task is not None: + kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_hinge_loss(preds, target, squared, **kwargs) + if task == "multiclass": + return multiclass_hinge_loss(preds, target, num_classes, squared, multiclass_mode, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) measure, total = _hinge_update(preds, target, squared=squared, multiclass_mode=multiclass_mode) return _hinge_compute(measure, total) diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 159af3a92c8..7b14197a399 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -33,6 +33,7 @@ _multilabel_confusion_matrix_update, ) from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.prints import rank_zero_warn def _jaccard_index_reduce( @@ -122,14 +123,14 @@ def binary_jaccard_index( kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_jaccard_index + >>> from torchmetrics.functional.classification import binary_jaccard_index >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> binary_jaccard_index(preds, target) tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_jaccard_index + >>> from torchmetrics.functional.classification import binary_jaccard_index >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> binary_jaccard_index(preds, target) @@ -158,8 +159,8 @@ def multiclass_jaccard_index( preds: Tensor, target: Tensor, num_classes: int, - ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r"""Calculates the Jaccard index for multiclass tasks. The `Jaccard index`_ (also known as @@ -180,8 +181,6 @@ def multiclass_jaccard_index( Args: num_classes: Integer specifing the number of classes - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation average: Defines the reduction that is applied over labels. 
Should be one of the following: @@ -190,19 +189,21 @@ def multiclass_jaccard_index( - ``weighted``: Calculates statistics for each label and computes weighted average using their support - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics.functional import multiclass_jaccard_index + >>> from torchmetrics.functional.classification import multiclass_jaccard_index >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_jaccard_index(preds, target, num_classes=3) tensor(0.6667) Example (pred is float tensor): - >>> from torchmetrics.functional import multiclass_jaccard_index + >>> from torchmetrics.functional.classification import multiclass_jaccard_index >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -238,8 +239,8 @@ def multilabel_jaccard_index( target: Tensor, num_labels: int, threshold: float = 0.5, - ignore_index: Optional[int] = None, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as @@ -261,8 +262,6 @@ def multilabel_jaccard_index( Args: num_classes: Integer specifing the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions - ignore_index: - Specifies a target value that is ignored and does not contribute to the metric calculation average: Defines the reduction that is applied over labels. Should be one of the following: @@ -271,19 +270,21 @@ def multilabel_jaccard_index( - ``weighted``: Calculates statistics for each label and computes weighted average using their support - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
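A minimal sketch (not taken from the patch) of how a call adapts to the reordered signatures above, where ``average`` now precedes ``ignore_index`` in ``multiclass_jaccard_index`` and ``multilabel_jaccard_index``; callers passing these arguments positionally should switch to keywords. The values and the expected result reuse the multiclass docstring example shown in this hunk:

    >>> import torch
    >>> from torchmetrics.functional.classification import multiclass_jaccard_index
    >>> target = torch.tensor([2, 1, 0, 0])
    >>> preds = torch.tensor([2, 1, 0, 1])
    >>> multiclass_jaccard_index(preds, target, num_classes=3, average="macro", ignore_index=None)
    tensor(0.6667)

Using keywords keeps existing call sites working regardless of the new parameter order.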
Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_jaccard_index + >>> from torchmetrics.functional.classification import multilabel_jaccard_index >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_jaccard_index(preds, target, num_labels=3) tensor(0.5000) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_jaccard_index + >>> from torchmetrics.functional.classification import multilabel_jaccard_index >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_jaccard_index(preds, target, num_labels=3) @@ -381,8 +382,19 @@ def jaccard_index( ignore_index: Optional[int] = None, absent_score: float = 0.0, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: - r"""Computes `Jaccard index`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Jaccard index`_ .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} @@ -441,6 +453,25 @@ def jaccard_index( >>> jaccard_index(pred, target, num_classes=2) tensor(0.9660) """ - + if task is not None: + kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_jaccard_index(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_jaccard_index(preds, target, num_classes, average, **kwargs) + if task == "multilabel": + return multilabel_jaccard_index(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold) return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 5f50f7fcf1d..a5774b2a8f7 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import ( _binary_confusion_matrix_arg_validation, @@ -31,6 +32,7 @@ _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) +from torchmetrics.utilities.prints import rank_zero_warn def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: @@ -90,14 +92,14 @@ def binary_matthews_corrcoef( kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_matthews_corrcoef + >>> from torchmetrics.functional.classification import binary_matthews_corrcoef >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0, 1, 0, 0]) >>> binary_matthews_corrcoef(preds, target) tensor(0.5774) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_matthews_corrcoef + >>> from torchmetrics.functional.classification import binary_matthews_corrcoef >>> target = torch.tensor([1, 1, 0, 0]) >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> binary_matthews_corrcoef(preds, target) @@ -147,14 +149,14 @@ def multiclass_matthews_corrcoef( kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): - >>> from torchmetrics.functional import multiclass_matthews_corrcoef + >>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_matthews_corrcoef(preds, target, num_classes=3) tensor(0.7000) Example (pred is float tensor): - >>> from torchmetrics.functional import multiclass_matthews_corrcoef + >>> from torchmetrics.functional.classification import multiclass_matthews_corrcoef >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -211,14 +213,14 @@ def multilabel_matthews_corrcoef( kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
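These matthews_corrcoef hunks only repoint the doctest imports at ``torchmetrics.functional.classification``; the computation itself is unchanged. A short sketch reusing the two binary examples above, showing that hard labels and probabilities thresholded at 0.5 give the same coefficient:

    >>> import torch
    >>> from torchmetrics.functional.classification import binary_matthews_corrcoef
    >>> target = torch.tensor([1, 1, 0, 0])
    >>> binary_matthews_corrcoef(torch.tensor([0, 1, 0, 0]), target)
    tensor(0.5774)
    >>> binary_matthews_corrcoef(torch.tensor([0.35, 0.85, 0.48, 0.01]), target)
    tensor(0.5774)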
Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_matthews_corrcoef + >>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) tensor(0.3333) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_matthews_corrcoef + >>> from torchmetrics.functional.classification import multilabel_matthews_corrcoef >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) @@ -272,8 +274,19 @@ def matthews_corrcoef( target: Tensor, num_classes: int, threshold: float = 0.5, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Tensor: r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + Calculates `Matthews correlation coefficient`_ that measures the general correlation or quality of a classification. In the binary case it is defined as: @@ -301,5 +314,25 @@ def matthews_corrcoef( tensor(0.5774) """ + if task is not None: + kwargs = dict(ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_matthews_corrcoef(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_matthews_corrcoef(preds, target, num_classes, **kwargs) + if task == "multilabel": + return multilabel_matthews_corrcoef(preds, target, num_labels, threshold, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) confmat = _matthews_corrcoef_update(preds, target, num_classes, threshold) return _matthews_corrcoef_compute(confmat) diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py index d952c4c0769..b00172cee7f 100644 --- a/src/torchmetrics/functional/classification/precision_recall.py +++ b/src/torchmetrics/functional/classification/precision_recall.py @@ -35,6 +35,7 @@ ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def _precision_recall_reduce( @@ -111,21 +112,21 @@ def binary_precision( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_precision + >>> from torchmetrics.functional.classification import binary_precision >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_precision(preds, target) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_precision + >>> from torchmetrics.functional.classification import binary_precision >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_precision(preds, target) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_precision + >>> from torchmetrics.functional.classification import binary_precision >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -212,7 +213,7 @@ def multiclass_precision( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_precision + >>> from torchmetrics.functional.classification import multiclass_precision >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_precision(preds, target, num_classes=3) @@ -221,7 +222,7 @@ def multiclass_precision( tensor([1.0000, 0.5000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_precision + >>> from torchmetrics.functional.classification import multiclass_precision >>> target = target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... 
[0.16, 0.26, 0.58], @@ -235,7 +236,7 @@ def multiclass_precision( tensor([1.0000, 0.5000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_precision + >>> from torchmetrics.functional.classification import multiclass_precision >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_precision(preds, target, num_classes=3, multidim_average='samplewise') @@ -318,7 +319,7 @@ def multilabel_precision( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_precision + >>> from torchmetrics.functional.classification import multilabel_precision >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_precision(preds, target, num_labels=3) @@ -327,7 +328,7 @@ def multilabel_precision( tensor([1.0000, 0.0000, 0.5000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_precision + >>> from torchmetrics.functional.classification import multilabel_precision >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_precision(preds, target, num_labels=3) @@ -336,7 +337,7 @@ def multilabel_precision( tensor([1.0000, 0.0000, 0.5000]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_precision + >>> from torchmetrics.functional.classification import multilabel_precision >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -405,21 +406,21 @@ def binary_recall( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. Example (preds is int tensor): - >>> from torchmetrics.functional import binary_recall + >>> from torchmetrics.functional.classification import binary_recall >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_recall(preds, target) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_recall + >>> from torchmetrics.functional.classification import binary_recall >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_recall(preds, target) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_recall + >>> from torchmetrics.functional.classification import binary_recall >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -506,7 +507,7 @@ def multiclass_recall( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_recall + >>> from torchmetrics.functional.classification import multiclass_recall >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_recall(preds, target, num_classes=3) @@ -515,7 +516,7 @@ def multiclass_recall( tensor([0.5000, 1.0000, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_recall + >>> from torchmetrics.functional.classification import multiclass_recall >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... 
[0.16, 0.26, 0.58], @@ -529,7 +530,7 @@ def multiclass_recall( tensor([0.5000, 1.0000, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_recall + >>> from torchmetrics.functional.classification import multiclass_recall >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_recall(preds, target, num_classes=3, multidim_average='samplewise') @@ -612,7 +613,7 @@ def multilabel_recall( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_recall + >>> from torchmetrics.functional.classification import multilabel_recall >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_recall(preds, target, num_labels=3) @@ -621,7 +622,7 @@ def multilabel_recall( tensor([1., 0., 1.]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_recall + >>> from torchmetrics.functional.classification import multilabel_recall >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_recall(preds, target, num_labels=3) @@ -630,7 +631,7 @@ def multilabel_recall( tensor([1., 0., 1.]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_recall + >>> from torchmetrics.functional.classification import multilabel_recall >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -718,8 +719,20 @@ def precision( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Precision`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Precision`_ .. math:: \text{Precision} = \frac{\text{TP}}{\text{TP} + \text{FP}} @@ -819,6 +832,26 @@ def precision( tensor(0.2500) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_precision(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_precision(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_precision(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") @@ -910,8 +943,20 @@ def recall( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Recall`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Recall`_ .. math:: \text{Recall} = \frac{\text{TP}}{\text{TP} + \text{FN}} @@ -1012,6 +1057,26 @@ def recall( tensor(0.2500) """ + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_recall(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_recall(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_recall(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index a49521cb685..84fc046ac7f 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -17,6 +17,7 @@ import torch from torch import Tensor, tensor from torch.nn import functional as F +from typing_extensions import Literal from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape @@ -286,7 +287,7 @@ def binary_precision_recall_curve( - thresholds: an 1d tensor of size (n_thresholds, ) with increasing threshold values Example: - >>> from torchmetrics.functional import binary_precision_recall_curve + >>> from torchmetrics.functional.classification import binary_precision_recall_curve >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> binary_precision_recall_curve(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE @@ -503,7 +504,7 @@ def multiclass_precision_recall_curve( then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. Example: - >>> from torchmetrics.functional import multiclass_precision_recall_curve + >>> from torchmetrics.functional.classification import multiclass_precision_recall_curve >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -727,7 +728,7 @@ def multilabel_precision_recall_curve( then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. Example: - >>> from torchmetrics.functional import multilabel_precision_recall_curve + >>> from torchmetrics.functional.classification import multilabel_precision_recall_curve >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -977,8 +978,21 @@ def precision_recall_curve( num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes precision-recall pairs for different thresholds. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. 
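A minimal sketch of the ``task`` routing added to ``precision`` and ``recall`` in the hunks above, assuming the wrappers stay importable from ``torchmetrics.functional`` as before; with ``task='binary'`` the call is forwarded to ``binary_recall``, so the result matches the ``binary_recall`` docstring example earlier in this diff:

    >>> import torch
    >>> from torchmetrics.functional import recall
    >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
    >>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
    >>> recall(preds, target, task="binary")
    tensor(0.6667)

Calls without ``task`` keep the old behaviour but emit the deprecation warning shown above.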
+ + Computes precision-recall pairs for different thresholds. Args: preds: predictions from model (probabilities). @@ -1038,5 +1052,25 @@ def precision_recall_curve( >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ + if task is not None: + kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_precision_recall_curve(preds, target, **kwargs) + if task == "multiclass": + return multiclass_precision_recall_curve(preds, target, num_classes, **kwargs) + if task == "multilabel": + return multilabel_precision_recall_curve(preds, target, num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label) return _precision_recall_curve_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/src/torchmetrics/functional/classification/ranking.py b/src/torchmetrics/functional/classification/ranking.py index b4fff78dab6..8446da4c51b 100644 --- a/src/torchmetrics/functional/classification/ranking.py +++ b/src/torchmetrics/functional/classification/ranking.py @@ -84,7 +84,7 @@ def multilabel_coverage_error( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import multilabel_coverage_error + >>> from torchmetrics.functional.classification import multilabel_coverage_error >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) @@ -155,7 +155,7 @@ def multilabel_ranking_average_precision( Set to ``False`` for faster computations. Example: - >>> from torchmetrics.functional import multilabel_ranking_average_precision + >>> from torchmetrics.functional.classification import multilabel_ranking_average_precision >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) @@ -237,7 +237,7 @@ def multilabel_ranking_loss( Set to ``False`` for faster computations. 
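The same pattern applies to ``precision_recall_curve`` in the hunk above: once ``task`` is given, the wrapper simply forwards to the matching specialized function. A hedged sketch using the ``binary_precision_recall_curve`` inputs from this diff (the returned tensors are not part of the changed lines, so they are left unprinted):

    >>> import torch
    >>> from torchmetrics.functional import precision_recall_curve
    >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> precision, recall, thresholds = precision_recall_curve(preds, target, task="binary", thresholds=None)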
Example: - >>> from torchmetrics.functional import multilabel_ranking_loss + >>> from torchmetrics.functional.classification import multilabel_ranking_loss >>> _ = torch.manual_seed(42) >>> preds = torch.rand(10, 5) >>> target = torch.randint(2, (10, 5)) diff --git a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py index 81dbdc694ee..2cb4d6124b8 100644 --- a/src/torchmetrics/functional/classification/recall_at_fixed_precision.py +++ b/src/torchmetrics/functional/classification/recall_at_fixed_precision.py @@ -132,7 +132,7 @@ def binary_recall_at_fixed_precision( - threshold: an scalar tensor with the corresponding threshold level Example: - >>> from torchmetrics.functional import binary_recall_at_fixed_precision + >>> from torchmetrics.functional.classification import binary_recall_at_fixed_precision >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None) @@ -233,7 +233,7 @@ def multiclass_recall_at_fixed_precision( - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class Example: - >>> from torchmetrics.functional import multiclass_recall_at_fixed_precision + >>> from torchmetrics.functional.classification import multiclass_recall_at_fixed_precision >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -342,7 +342,7 @@ def multilabel_recall_at_fixed_precision( - thresholds: an 1d tensor of size (n_classes, ) with the corresponding threshold level per class Example: - >>> from torchmetrics.functional import multilabel_recall_at_fixed_precision + >>> from torchmetrics.functional.classification import multilabel_recall_at_fixed_precision >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index a9e6e1364d0..f064a294219 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.precision_recall_curve import ( _binary_clf_curve, @@ -135,7 +136,7 @@ def binary_roc( - thresholds: an 1d tensor of size (n_thresholds, ) with decreasing threshold values Example: - >>> from torchmetrics.functional import binary_roc + >>> from torchmetrics.functional.classification import binary_roc >>> preds = torch.tensor([0, 0.5, 0.7, 0.8]) >>> target = torch.tensor([0, 1, 1, 0]) >>> binary_roc(preds, target, thresholds=None) # doctest: +NORMALIZE_WHITESPACE @@ -242,7 +243,7 @@ def multiclass_roc( then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all classes. Example: - >>> from torchmetrics.functional import multiclass_roc + >>> from torchmetrics.functional.classification import multiclass_roc >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], ... [0.05, 0.75, 0.05, 0.05, 0.05], ... [0.05, 0.05, 0.75, 0.05, 0.05], @@ -378,7 +379,7 @@ def multilabel_roc( then a single 1d tensor of size (n_thresholds, ) is returned with shared threshold values for all labels. 
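The recall_at_fixed_precision hunks above likewise only update the doctest import path. A short sketch of the call, assuming (as the Returns lines above state) that the function yields a recall tensor together with the corresponding threshold; the concrete values are not shown in the diff, so none are asserted here:

    >>> import torch
    >>> from torchmetrics.functional.classification import binary_recall_at_fixed_precision
    >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> best_recall, best_threshold = binary_recall_at_fixed_precision(preds, target, min_precision=0.5, thresholds=None)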
Example: - >>> from torchmetrics.functional import multilabel_roc + >>> from torchmetrics.functional.classification import multilabel_roc >>> preds = torch.tensor([[0.75, 0.05, 0.35], ... [0.45, 0.75, 0.05], ... [0.05, 0.55, 0.75], @@ -603,8 +604,21 @@ def roc( num_classes: Optional[int] = None, pos_label: Optional[int] = None, sample_weights: Optional[Sequence] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: - """Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes the Receiver Operating Characteristic (ROC). Works with both binary, multiclass and multilabel input. .. note:: @@ -681,5 +695,25 @@ def roc( tensor([1.7576, 0.7576, 0.3680, 0.3468, 0.0745]), tensor([1.1837, 0.1837, 0.1338, 0.1183, 0.1138])] """ + if task is not None: + kwargs = dict(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_roc(preds, target, **kwargs) + if task == "multiclass": + return multiclass_roc(preds, target, num_classes, **kwargs) + if task == "multilabel": + return multilabel_roc(preds, target, num_labels, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) preds, target, num_classes, pos_label = _roc_update(preds, target, num_classes, pos_label) return _roc_compute(preds, target, num_classes, pos_label, sample_weights) diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 5eadad347e6..5893c2fcea3 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -35,6 +35,7 @@ ) from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def _specificity_reduce( @@ -108,21 +109,21 @@ def binary_specificity( is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. 
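As with the other wrappers in this patch, ``roc`` now routes on ``task``. A minimal sketch, assuming ``roc`` remains importable from ``torchmetrics.functional``; the inputs reuse the ``binary_roc`` docstring example above and the returned curves are left unprinted:

    >>> import torch
    >>> from torchmetrics.functional import roc
    >>> preds = torch.tensor([0, 0.5, 0.7, 0.8])
    >>> target = torch.tensor([0, 1, 1, 0])
    >>> fpr, tpr, thresholds = roc(preds, target, task="binary")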
Example (preds is int tensor): - >>> from torchmetrics.functional import binary_specificity + >>> from torchmetrics.functional.classification import binary_specificity >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_specificity(preds, target) tensor(0.6667) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_specificity + >>> from torchmetrics.functional.classification import binary_specificity >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_specificity(preds, target) tensor(0.6667) Example (multidim tensors): - >>> from torchmetrics.functional import binary_specificity + >>> from torchmetrics.functional.classification import binary_specificity >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -209,7 +210,7 @@ def multiclass_specificity( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_specificity + >>> from torchmetrics.functional.classification import multiclass_specificity >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_specificity(preds, target, num_classes=3) @@ -218,7 +219,7 @@ def multiclass_specificity( tensor([1.0000, 0.6667, 1.0000]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_specificity + >>> from torchmetrics.functional.classification import multiclass_specificity >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... [0.16, 0.26, 0.58], @@ -232,7 +233,7 @@ def multiclass_specificity( tensor([1.0000, 0.6667, 1.0000]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_specificity + >>> from torchmetrics.functional.classification import multiclass_specificity >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise') @@ -315,7 +316,7 @@ def multilabel_specificity( - If ``average=None/'none'``, the shape will be ``(N, C)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_specificity + >>> from torchmetrics.functional.classification import multilabel_specificity >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_specificity(preds, target, num_labels=3) @@ -324,7 +325,7 @@ def multilabel_specificity( tensor([1., 1., 0.]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_specificity + >>> from torchmetrics.functional.classification import multilabel_specificity >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_specificity(preds, target, num_labels=3) @@ -333,7 +334,7 @@ def multilabel_specificity( tensor([1., 1., 0.]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_specificity + >>> from torchmetrics.functional.classification import multilabel_specificity >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... 
[ @@ -415,8 +416,20 @@ def specificity( threshold: float = 0.5, top_k: Optional[int] = None, multiclass: Optional[bool] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes `Specificity`_ + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + Computes `Specificity`_ .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} @@ -515,7 +528,26 @@ def specificity( >>> specificity(preds, target, average='micro') tensor(0.6250) """ - + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_specificity(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_specificity(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_specificity(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. 
From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 13739ab6a1c..f68e1de8080 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -183,21 +183,21 @@ def binary_stat_scores( - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` Example (preds is int tensor): - >>> from torchmetrics.functional import binary_stat_scores + >>> from torchmetrics.functional.classification import binary_stat_scores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) >>> binary_stat_scores(preds, target) tensor([2, 1, 2, 1, 3]) Example (preds is float tensor): - >>> from torchmetrics.functional import binary_stat_scores + >>> from torchmetrics.functional.classification import binary_stat_scores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) >>> binary_stat_scores(preds, target) tensor([2, 1, 2, 1, 3]) Example (multidim tensors): - >>> from torchmetrics.functional import binary_stat_scores + >>> from torchmetrics.functional.classification import binary_stat_scores >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -506,7 +506,7 @@ def multiclass_stat_scores( - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multiclass_stat_scores + >>> from torchmetrics.functional.classification import multiclass_stat_scores >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([2, 1, 0, 1]) >>> multiclass_stat_scores(preds, target, num_classes=3, average='micro') @@ -517,7 +517,7 @@ def multiclass_stat_scores( [1, 0, 3, 0, 1]]) Example (preds is float tensor): - >>> from torchmetrics.functional import multiclass_stat_scores + >>> from torchmetrics.functional.classification import multiclass_stat_scores >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ ... 
[0.16, 0.26, 0.58], @@ -533,7 +533,7 @@ def multiclass_stat_scores( [1, 0, 3, 0, 1]]) Example (multidim tensors): - >>> from torchmetrics.functional import multiclass_stat_scores + >>> from torchmetrics.functional.classification import multiclass_stat_scores >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average='micro') @@ -761,7 +761,7 @@ def multilabel_stat_scores( - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` Example (preds is int tensor): - >>> from torchmetrics.functional import multilabel_stat_scores + >>> from torchmetrics.functional.classification import multilabel_stat_scores >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) >>> multilabel_stat_scores(preds, target, num_labels=3, average='micro') @@ -772,7 +772,7 @@ def multilabel_stat_scores( [1, 1, 0, 0, 1]]) Example (preds is float tensor): - >>> from torchmetrics.functional import multilabel_stat_scores + >>> from torchmetrics.functional.classification import multilabel_stat_scores >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) >>> multilabel_stat_scores(preds, target, num_labels=3, average='micro') @@ -783,7 +783,7 @@ def multilabel_stat_scores( [1, 1, 0, 0, 1]]) Example (multidim tensors): - >>> from torchmetrics.functional import multilabel_stat_scores + >>> from torchmetrics.functional.classification import multilabel_stat_scores >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) >>> preds = torch.tensor( ... [ @@ -1093,8 +1093,22 @@ def stat_scores( threshold: float = 0.5, multiclass: Optional[bool] = None, ignore_index: Optional[int] = None, + task: Optional[Literal["binary", "multiclass", "multilabel"]] = None, + num_labels: Optional[int] = None, + average: Optional[str] = "micro", + multidim_average: Optional[Literal["global", "samplewise"]] = "global", + validate_args: bool = True, ) -> Tensor: - r"""Computes the number of true positives, false positives, true negatives, false negatives. + r""" + .. note:: + From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification + metric. Moving forward we recommend using these versions. This base metric will still work as it did + prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required + and the general order of arguments may change, such that this metric will just function as an single + entrypoint to calling the three specialized versions. + + + Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. The reduction method (how the statistics are aggregated) is controlled by the @@ -1210,13 +1224,26 @@ def stat_scores( tensor([2, 2, 6, 2, 4]) """ - rank_zero_warn( - "`torchmetrics.functional.stat_scores` have been deprecated in v0.10 in favor of" - "`torchmetrics.functional.binary_stat_scores`, `torchmetrics.functional.multiclass_stat_scores`" - "and `torchmetrics.functional.multilabel_stat_scores`. Please upgrade to the version that matches" - "your problem (API may have changed). 
This function will be removed v0.11.", - DeprecationWarning, - ) + if task is not None: + kwargs = dict(multidim_average=multidim_average, ignore_index=ignore_index, validate_args=validate_args) + if task == "binary": + return binary_stat_scores(preds, target, threshold, **kwargs) + if task == "multiclass": + return multiclass_stat_scores(preds, target, num_classes, average, top_k, **kwargs) + if task == "multilabel": + return multilabel_stat_scores(preds, target, num_labels, threshold, average, **kwargs) + raise ValueError( + f"Expected argument `task` to either be `'binary'`, `'multiclass'` or `'multilabel'` but got {task}" + ) + else: + rank_zero_warn( + "From v0.10 an `'binary_*'`, `'multiclass_*', `'multilabel_*'` version now exist of each classification" + " metric. Moving forward we recommend using these versions. This base metric will still work as it did" + " prior to v0.10 until v0.11. From v0.11 the `task` argument introduced in this metric will be required" + " and the general order of arguments may change, such that this metric will just function as an single" + " entrypoint to calling the three specialized versions.", + DeprecationWarning, + ) if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index e9486dc3577..32e1f0c4312 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -392,10 +392,9 @@ def wrapped_func(*args: Any, **kwargs: Any) -> None: except RuntimeError as err: if "Expected all tensors to be on" in str(err): raise RuntimeError( - "Encountered different devices in metric calculation" - " (see stacktrace for details)." - "This could be due to the metric class not being on the same device as input." - f"Instead of `metric={self.__class__.__name__}(...)` try to do" + "Encountered different devices in metric calculation (see stacktrace for details)." + " This could be due to the metric class not being on the same device as input." + f" Instead of `metric={self.__class__.__name__}(...)` try to do" f" `metric={self.__class__.__name__}(...).to(device)` where" " device corresponds to the device of the input." 
) from err diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 44f32874819..accaa395fad 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -480,413 +480,3 @@ def test_multilabel_accuracy_half_gpu(self, input, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_accuracy(preds, target, subset_accuracy): -# sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) -# sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - -# if mode == DataType.MULTIDIM_MULTICLASS and not subset_accuracy: -# sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) -# sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) -# elif mode == DataType.MULTIDIM_MULTICLASS and subset_accuracy: -# return np.all(sk_preds == sk_target, axis=(1, 2)).mean() -# elif mode == DataType.MULTILABEL and not subset_accuracy: -# sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) - -# return sk_accuracy(y_true=sk_target, y_pred=sk_preds) - - -# @pytest.mark.parametrize( -# "preds, target, subset_accuracy, mdmc_average", -# [ -# (_input_binary_logits.preds, _input_binary_logits.target, False, None), -# (_input_binary_prob.preds, _input_binary_prob.target, False, None), -# (_input_binary.preds, _input_binary.target, False, None), -# (_input_mlb_prob.preds, _input_mlb_prob.target, True, None), -# (_input_mlb_logits.preds, _input_mlb_logits.target, False, None), -# (_input_mlb_prob.preds, _input_mlb_prob.target, False, None), -# (_input_mlb.preds, _input_mlb.target, True, None), -# (_input_mlb.preds, _input_mlb.target, False, "global"), -# (_input_mcls_prob.preds, _input_mcls_prob.target, False, None), -# (_input_mcls_logits.preds, _input_mcls_logits.target, False, None), -# (_input_mcls.preds, _input_mcls.target, False, None), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, False, "global"), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, True, None), -# (_input_mdmc.preds, _input_mdmc.target, False, "global"), -# (_input_mdmc.preds, _input_mdmc.target, True, None), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target, True, None), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target, False, None), -# (_input_mlmd.preds, _input_mlmd.target, True, None), -# (_input_mlmd.preds, _input_mlmd.target, False, "global"), -# ], -# ) -# class TestAccuracies(MetricTester): -# @pytest.mark.parametrize("ddp", [False, True]) -# @pytest.mark.parametrize("dist_sync_on_step", [False, True]) -# def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy, mdmc_average): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=Accuracy, -# sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average}, -# ) - -# def test_accuracy_fn(self, preds, target, subset_accuracy, mdmc_average): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=accuracy, -# sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), -# metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, -# ) - -# def 
test_accuracy_differentiability(self, preds, target, subset_accuracy, mdmc_average): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=Accuracy, -# metric_functional=accuracy, -# metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy, "mdmc_average": mdmc_average}, -# ) - - -# _l1to4 = [0.1, 0.2, 0.3, 0.4] -# _l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) -# _l1to4t3_mcls = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] - -# # The preds in these examples always put highest probability on class 3, second highest on class 2, -# # third highest on class 1, and lowest on class 0 -# _topk_preds_mcls = tensor([_l1to4t3, _l1to4t3]).float() -# _topk_target_mcls = tensor([[1, 2, 3], [2, 1, 0]]) - -# # This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) -# _topk_preds_mdmc = tensor([_l1to4t3_mcls, _l1to4t3_mcls]).float() -# _topk_target_mdmc = tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) - -# # Multilabel -# _ml_t1 = [0.8, 0.2, 0.8, 0.2] -# _ml_t2 = [_ml_t1, _ml_t1] -# _ml_ta2 = [[1, 0, 1, 1], [0, 1, 1, 0]] -# _av_preds_ml = tensor([_ml_t2, _ml_t2]).float() -# _av_target_ml = tensor([_ml_ta2, _ml_ta2]) - -# # Inputs with negative target values to be ignored -# Input = namedtuple("Input", ["preds", "target", "ignore_index", "result"]) -# _binary_with_neg_tgt = Input( -# preds=torch.tensor([0, 1, 0]), target=torch.tensor([0, 1, -1]), ignore_index=-1, result=torch.tensor(1.0) -# ) -# _multiclass_logits_with_neg_tgt = Input( -# preds=torch.tensor([[0.8, 0.1], [0.2, 0.7], [0.5, 0.5]]), -# target=torch.tensor([0, 1, -1]), -# ignore_index=-1, -# result=torch.tensor(1.0), -# ) -# _multidim_multiclass_with_neg_tgt = Input( -# preds=torch.tensor([[0, 0], [1, 1], [0, 0]]), -# target=torch.tensor([[0, 0], [-1, 1], [1, -1]]), -# ignore_index=-1, -# result=torch.tensor(0.75), -# ) -# _multidim_multiclass_logits_with_neg_tgt = Input( -# preds=torch.tensor([[[0.8, 0.7], [0.2, 0.4]], [[0.1, 0.2], [0.9, 0.8]], [[0.7, 0.9], [0.2, 0.4]]]), -# target=torch.tensor([[0, 0], [-1, 1], [1, -1]]), -# ignore_index=-1, -# result=torch.tensor(0.75), -# ) - - -# # Replace with a proper sk_metric test once sklearn 0.24 hits :) -# @pytest.mark.parametrize( -# "preds, target, exp_result, k, subset_accuracy", -# [ -# (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, False), -# (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, False), -# (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, False), -# (_topk_preds_mcls, _topk_target_mcls, 1 / 6, 1, True), -# (_topk_preds_mcls, _topk_target_mcls, 3 / 6, 2, True), -# (_topk_preds_mcls, _topk_target_mcls, 5 / 6, 3, True), -# (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), -# (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), -# (_topk_preds_mdmc, _topk_target_mdmc, 13 / 18, 3, False), -# (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), -# (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), -# (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), -# (_av_preds_ml, _av_target_ml, 5 / 8, None, False), -# (_av_preds_ml, _av_target_ml, 0, None, True), -# ], -# ) -# def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): -# topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy, mdmc_average="global") - -# for batch in range(preds.shape[0]): -# topk(preds[batch], target[batch]) - -# assert topk.compute() == exp_result - -# # Test functional -# total_samples = target.shape[0] * target.shape[1] - -# preds = preds.view(total_samples, 4, -1) -# 
target = target.view(total_samples, -1) - -# assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result - - -# # Only MC and MDMC with probs input type should be accepted for top_k -# @pytest.mark.parametrize( -# "preds, target", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target), -# (_input_binary.preds, _input_binary.target), -# (_input_mlb_prob.preds, _input_mlb_prob.target), -# (_input_mlb.preds, _input_mlb.target), -# (_input_mcls.preds, _input_mcls.target), -# (_input_mdmc.preds, _input_mdmc.target), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target), -# (_input_mlmd.preds, _input_mlmd.target), -# ], -# ) -# def test_topk_accuracy_wrong_input_types(preds, target): -# topk = Accuracy(top_k=1) - -# with pytest.raises(ValueError): -# topk(preds[0], target[0]) - -# with pytest.raises(ValueError): -# accuracy(preds[0], target[0], top_k=1) - - -# @pytest.mark.parametrize( -# "average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold", -# [ -# ("unknown", None, None, _input_binary, None, None, 0.5), -# ("micro", "unknown", None, _input_binary, None, None, 0.5), -# ("macro", None, None, _input_binary, None, None, 0.5), -# ("micro", None, None, _input_mdmc_prob, None, None, 0.5), -# ("micro", None, None, _input_binary_prob, 0, None, 0.5), -# ("micro", None, None, _input_mcls_prob, NUM_CLASSES, None, 0.5), -# ("micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES, None, 0.5), -# (None, None, None, _input_mcls_prob, None, 0, 0.5), -# (None, None, None, _input_mcls_prob, None, None, 1.5), -# ], -# ) -# def test_wrong_params(average, mdmc_average, num_classes, inputs, ignore_index, top_k, threshold): -# preds, target = inputs.preds, inputs.target - -# with pytest.raises(ValueError): -# acc = Accuracy( -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# threshold=threshold, -# top_k=top_k, -# ) -# acc(preds[0], target[0]) -# acc.compute() - -# with pytest.raises(ValueError): -# accuracy( -# preds[0], -# target[0], -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# threshold=threshold, -# top_k=top_k, -# ) - - -# @pytest.mark.parametrize( -# "preds_mc, target_mc, preds_ml, target_ml", -# [ -# ( -# tensor([0, 1, 1, 1]), -# tensor([2, 2, 1, 1]), -# tensor([[0.8, 0.2, 0.8, 0.7], [0.6, 0.4, 0.6, 0.5]]), -# tensor([[1, 0, 1, 1], [0, 0, 1, 0]]), -# ) -# ], -# ) -# def test_different_modes(preds_mc, target_mc, preds_ml, target_ml): -# acc = Accuracy() -# acc(preds_mc, target_mc) -# with pytest.raises(ValueError, match="^[You cannot use]"): -# acc(preds_ml, target_ml) - - -# _bin_t1 = [0.7, 0.6, 0.2, 0.1] -# _av_preds_bin = tensor([_bin_t1, _bin_t1]).float() -# _av_target_bin = tensor([[1, 0, 0, 0], [0, 1, 1, 0]]) - - -# @pytest.mark.parametrize( -# "preds, target, num_classes, exp_result, average, mdmc_average", -# [ -# (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 4, "macro", None), -# (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "weighted", None), -# (_topk_preds_mcls, _topk_target_mcls, 4, [0.0, 0.0, 0.0, 1.0], "none", None), -# (_topk_preds_mcls, _topk_target_mcls, 4, 1 / 6, "samples", None), -# (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 24, "macro", "samplewise"), -# (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "weighted", "samplewise"), -# (_topk_preds_mdmc, _topk_target_mdmc, 4, [0.0, 0.0, 0.0, 1 / 6], "none", "samplewise"), -# (_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "samplewise"), -# 
(_topk_preds_mdmc, _topk_target_mdmc, 4, 1 / 6, "samples", "global"), -# (_av_preds_ml, _av_target_ml, 4, 5 / 8, "macro", None), -# (_av_preds_ml, _av_target_ml, 4, 0.70000005, "weighted", None), -# (_av_preds_ml, _av_target_ml, 4, [1 / 2, 1 / 2, 1.0, 1 / 2], "none", None), -# (_av_preds_ml, _av_target_ml, 4, 5 / 8, "samples", None), -# ], -# ) -# def test_average_accuracy(preds, target, num_classes, exp_result, average, mdmc_average): -# acc = Accuracy(num_classes=num_classes, average=average, mdmc_average=mdmc_average) - -# for batch in range(preds.shape[0]): -# acc(preds[batch], target[batch]) - -# assert (acc.compute() == tensor(exp_result)).all() - -# # Test functional -# total_samples = target.shape[0] * target.shape[1] - -# preds = preds.view(total_samples, num_classes, -1) -# target = target.view(total_samples, -1) - -# acc_score = accuracy(preds, target, num_classes=num_classes, average=average, mdmc_average=mdmc_average) -# assert (acc_score == tensor(exp_result)).all() - - -# @pytest.mark.parametrize( -# "preds, target, num_classes, exp_result, average, multiclass", -# [ -# (_av_preds_bin, _av_target_bin, 2, 19 / 30, "macro", True), -# (_av_preds_bin, _av_target_bin, 2, 5 / 8, "weighted", True), -# (_av_preds_bin, _av_target_bin, 2, [3 / 5, 2 / 3], "none", True), -# (_av_preds_bin, _av_target_bin, 2, 5 / 8, "samples", True), -# ], -# ) -# def test_average_accuracy_bin(preds, target, num_classes, exp_result, average, multiclass): -# acc = Accuracy(num_classes=num_classes, average=average, multiclass=multiclass) - -# for batch in range(preds.shape[0]): -# acc(preds[batch], target[batch]) - -# assert (acc.compute() == tensor(exp_result)).all() - -# # Test functional -# total_samples = target.shape[0] * target.shape[1] - -# preds = preds.view(total_samples, -1) -# target = target.view(total_samples, -1) -# acc_score = accuracy(preds, target, num_classes=num_classes, average=average, multiclass=multiclass) -# assert (acc_score == tensor(exp_result)).all() - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Accuracy, accuracy)]) -# @pytest.mark.parametrize( -# "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -# ) -# def test_class_not_present(metric_class, metric_fn, ignore_index, expected): -# """This tests that when metric is computed per class and a given class is not present in both the `preds` and -# `target`, the resulting score is `nan`.""" -# preds = torch.tensor([0, 0, 0]) -# target = torch.tensor([0, 0, 0]) -# num_classes = 2 - -# # test functional -# result_fn = metric_fn( -# preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index -# ) -# assert torch.allclose(expected, result_fn, equal_nan=True) - -# # test class -# cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) -# cl_metric(preds, target) -# result_cl = cl_metric.compute() -# assert torch.allclose(expected, result_cl, equal_nan=True) - - -# @pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) -# def test_same_input(average): -# preds = _input_miss_class.preds -# target = _input_miss_class.target -# preds_flat = torch.cat(list(preds), dim=0) -# target_flat = torch.cat(list(target), dim=0) - -# mc = Accuracy(num_classes=NUM_CLASSES, average=average) -# for i in range(NUM_BATCHES): -# mc.update(preds[i], target[i]) -# class_res = mc.compute() -# func_res = accuracy(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average) -# sk_res = 
sk_accuracy(target_flat, preds_flat) - -# assert torch.allclose(class_res, torch.tensor(sk_res).float()) -# assert torch.allclose(func_res, torch.tensor(sk_res).float()) - - -# @pytest.mark.parametrize( -# "preds, target, ignore_index, result", -# [ -# ( -# _binary_with_neg_tgt.preds, -# _binary_with_neg_tgt.target, -# _binary_with_neg_tgt.ignore_index, -# _binary_with_neg_tgt.result, -# ), -# ( -# _multiclass_logits_with_neg_tgt.preds, -# _multiclass_logits_with_neg_tgt.target, -# _multiclass_logits_with_neg_tgt.ignore_index, -# _multiclass_logits_with_neg_tgt.result, -# ), -# ( -# _multidim_multiclass_with_neg_tgt.preds, -# _multidim_multiclass_with_neg_tgt.target, -# _multidim_multiclass_with_neg_tgt.ignore_index, -# _multidim_multiclass_with_neg_tgt.result, -# ), -# ( -# _multidim_multiclass_logits_with_neg_tgt.preds, -# _multidim_multiclass_logits_with_neg_tgt.target, -# _multidim_multiclass_logits_with_neg_tgt.ignore_index, -# _multidim_multiclass_logits_with_neg_tgt.result, -# ), -# ], -# ) -# def test_negative_ignore_index(preds, target, ignore_index, result): -# # We deduct -1 for an ignored index -# num_classes = len(target.unique()) - 1 - -# # Test class -# acc = Accuracy(num_classes=num_classes, ignore_index=ignore_index) -# acc_score = acc(preds, target) -# assert torch.allclose(acc_score, result) -# # Test functional metrics -# acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index) -# assert torch.allclose(acc_score, result) - -# # If the ignore index is not set properly, we expect to see an error -# ignore_index = None -# # Test class -# acc = Accuracy(num_classes=num_classes, ignore_index=ignore_index) -# with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"): -# acc_score = acc(preds, target) - -# # Test functional -# with pytest.raises(ValueError, match="^[The `target` has to be a non-negative tensor.]"): -# acc_score = accuracy(preds, target, num_classes=num_classes, ignore_index=ignore_index) - - -# def test_negmetric_noneavg(noneavg=_negmetric_noneavg): -# acc = MetricWrapper(Accuracy(average="none", num_classes=noneavg["pred1"].shape[1])) -# result1 = acc(noneavg["pred1"], noneavg["target1"]) -# assert torch.allclose(noneavg["res1"], result1, equal_nan=True) -# result2 = acc(noneavg["pred2"], noneavg["target2"]) -# assert torch.allclose(noneavg["res2"], result2, equal_nan=True) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 5b33b649802..45b47c561f7 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -359,215 +359,3 @@ def test_multilabel_auroc_threshold_arg(self, input, average): pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) ) assert torch.allclose(ap1, ap2) - - -# -------------------------- Old stuff -------------------------- - - -# def _sk_auroc_binary_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): -# # todo: `multi_class` is unused -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() -# return sk_roc_auc_score(y_true=sk_target, y_score=sk_preds, average=average, max_fpr=max_fpr) - - -# def _sk_auroc_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.view(-1).numpy() -# return sk_roc_auc_score( -# y_true=sk_target, -# y_score=sk_preds, -# average=average, -# 
max_fpr=max_fpr, -# multi_class=multi_class, -# ) - - -# def _sk_auroc_multidim_multiclass_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.view(-1).numpy() -# return sk_roc_auc_score( -# y_true=sk_target, -# y_score=sk_preds, -# average=average, -# max_fpr=max_fpr, -# multi_class=multi_class, -# ) - - -# def _sk_auroc_multilabel_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.reshape(-1, num_classes).numpy() -# return sk_roc_auc_score( -# y_true=sk_target, -# y_score=sk_preds, -# average=average, -# max_fpr=max_fpr, -# multi_class=multi_class, -# ) - - -# def _sk_auroc_multilabel_multidim_prob(preds, target, num_classes, average="macro", max_fpr=None, multi_class="ovr"): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# return sk_roc_auc_score( -# y_true=sk_target, -# y_score=sk_preds, -# average=average, -# max_fpr=max_fpr, -# multi_class=multi_class, -# ) - - -# @pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_auroc_binary_prob, 1), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_auroc_multiclass_prob, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_auroc_multidim_multiclass_prob, NUM_CLASSES), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_auroc_multilabel_prob, NUM_CLASSES), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_auroc_multilabel_multidim_prob, NUM_CLASSES), -# ], -# ) -# class TestAUROC(MetricTester): -# @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_auroc(self, preds, target, sk_metric, num_classes, average, max_fpr, ddp, dist_sync_on_step): -# # max_fpr different from None is not support in multi class -# if max_fpr is not None and num_classes != 1: -# pytest.skip("max_fpr parameter not support for multi class or multi label") - -# # max_fpr only supported for torch v1.6 or higher -# if max_fpr is not None and _TORCH_LOWER_1_6: -# pytest.skip("requires torch v1.6 or higher to test max_fpr argument") - -# # average='micro' only supported for multilabel -# if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: -# pytest.skip("micro argument only support for multilabel input") - -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=AUROC, -# sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, -# ) - -# @pytest.mark.parametrize("average", ["macro", "weighted", "micro"]) -# def test_auroc_functional(self, preds, target, sk_metric, num_classes, average, max_fpr): -# # max_fpr different from None is not support in multi class -# if max_fpr is not None and num_classes != 1: -# pytest.skip("max_fpr parameter not support for multi class or multi label") - -# # max_fpr only supported for torch v1.6 or higher -# if max_fpr is not None and _TORCH_LOWER_1_6: -# 
pytest.skip("requires torch v1.6 or higher to test max_fpr argument") - -# # average='micro' only supported for multilabel -# if average == "micro" and preds.ndim > 2 and preds.ndim == target.ndim + 1: -# pytest.skip("micro argument only support for multilabel input") - -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=auroc, -# sk_metric=partial(sk_metric, num_classes=num_classes, average=average, max_fpr=max_fpr), -# metric_args={"num_classes": num_classes, "average": average, "max_fpr": max_fpr}, -# ) - -# def test_auroc_differentiability(self, preds, target, sk_metric, num_classes, max_fpr): -# # max_fpr different from None is not support in multi class -# if max_fpr is not None and num_classes != 1: -# pytest.skip("max_fpr parameter not support for multi class or multi label") - -# # max_fpr only supported for torch v1.6 or higher -# if max_fpr is not None and _TORCH_LOWER_1_6: -# pytest.skip("requires torch v1.6 or higher to test max_fpr argument") - -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=AUROC, -# metric_functional=auroc, -# metric_args={"num_classes": num_classes, "max_fpr": max_fpr}, -# ) - - -# def test_error_on_different_mode(): -# """test that an error is raised if the user pass in data of different modes (binary, multi-label, multi- -# class)""" -# metric = AUROC() -# # pass in multi-class data -# metric.update(torch.randn(10, 5).softmax(dim=-1), torch.randint(0, 5, (10,))) -# with pytest.raises(ValueError, match=r"The mode of data.* should be constant.*"): -# # pass in multi-label data -# metric.update(torch.rand(10, 5), torch.randint(0, 2, (10, 5))) - - -# def test_error_multiclass_no_num_classes(): -# with pytest.raises( -# ValueError, match="Detected input to `multiclass` but you did not provide `num_classes` argument" -# ): -# _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,))) - - -# @pytest.mark.parametrize("device", ["cpu", "cuda"]) -# def test_weighted_with_empty_classes(device): -# """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists. 
- -# Tests that the proper warnings and errors are raised -# """ -# if not torch.cuda.is_available() and device == "cuda": -# pytest.skip("Test requires gpu to run") - -# preds = torch.tensor( -# [ -# [0.90, 0.05, 0.05], -# [0.05, 0.90, 0.05], -# [0.05, 0.05, 0.90], -# [0.85, 0.05, 0.10], -# [0.10, 0.10, 0.80], -# ] -# ).to(device) -# target = torch.tensor([0, 1, 1, 2, 2]).to(device) -# num_classes = 3 -# _auroc = auroc(preds, target, average="weighted", num_classes=num_classes) - -# # Add in a class with zero observations at second to last index -# preds = torch.cat( -# (preds[:, : num_classes - 1], torch.rand_like(preds[:, 0:1]), preds[:, num_classes - 1 :]), axis=1 -# ) -# # Last class (2) gets moved to 3 -# target[target == num_classes - 1] = num_classes -# with pytest.warns(UserWarning, match="Class 2 had 0 observations, omitted from AUROC calculation"): -# _auroc_empty_class = auroc(preds, target, average="weighted", num_classes=num_classes + 1) -# assert _auroc == _auroc_empty_class - -# target = torch.zeros_like(target) -# with pytest.raises(ValueError, match="Found 1 non-empty class in `multiclass` AUROC calculation"): -# _ = auroc(preds, target, average="weighted", num_classes=num_classes + 1) - - -# def test_warnings_on_missing_class(): -# """Test that a warning is given if either the positive or negative class is missing.""" -# metric = AUROC() -# # no positive samples -# warning = ( -# "No positive samples in targets, true positive value should be meaningless." -# " Returning zero tensor in true positive score" -# ) -# with pytest.warns(UserWarning, match=warning): -# score = metric(torch.randn(10).sigmoid(), torch.zeros(10).int()) -# assert score == 0 - -# warning = ( -# "No negative samples in targets, false positive value should be meaningless." 
-# " Returning zero tensor in false positive score" -# ) -# with pytest.warns(UserWarning, match=warning): -# score = metric(torch.randn(10).sigmoid(), torch.ones(10).int()) -# assert score == 0 diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index bcc3d6d43d6..8a331b214da 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -363,153 +363,3 @@ def test_multilabel_average_precision_threshold_arg(self, input, average): pred, true, num_labels=NUM_CLASSES, average=average, thresholds=torch.linspace(0, 1, 100) ) assert torch.allclose(ap1, ap2) - - -# -------------------------- Old stuff -------------------------- - - -# def _sk_average_precision_score(y_true, probas_pred, num_classes=1, average=None): -# if num_classes == 1: -# return sk_average_precision_score(y_true, probas_pred) - -# res = [] -# for i in range(num_classes): -# y_true_temp = np.zeros_like(y_true) -# y_true_temp[y_true == i] = 1 -# res.append(sk_average_precision_score(y_true_temp, probas_pred[:, i])) - -# if average == "macro": -# return np.array(res).mean() -# if average == "weighted": -# weights = np.bincount(y_true) if y_true.max() > 1 else y_true.sum(axis=0) -# weights = weights / sum(weights) -# return (np.array(res) * weights).sum() - -# return res - - -# def _sk_avg_prec_binary_prob(preds, target, num_classes=1, average=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_average_precision_score( -# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average -# ) - - -# def _sk_avg_prec_multiclass_prob(preds, target, num_classes=1, average=None): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_average_precision_score( -# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average -# ) - - -# def _sk_avg_prec_multilabel_prob(preds, target, num_classes=1, average=None): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.view(-1, num_classes).numpy() -# return sk_average_precision_score(sk_target, sk_preds, average=average) - - -# def _sk_avg_prec_multidim_multiclass_prob(preds, target, num_classes=1, average=None): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.view(-1).numpy() -# return _sk_average_precision_score( -# y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, average=average -# ) - - -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_avg_prec_binary_prob, 1), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_avg_prec_multiclass_prob, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_avg_prec_multidim_multiclass_prob, NUM_CLASSES), -# (_input_multilabel.preds, _input_multilabel.target, _sk_avg_prec_multilabel_prob, NUM_CLASSES), -# ], -# ) -# class TestAveragePrecision(MetricTester): -# @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_average_precision(self, preds, target, sk_metric, num_classes, average, ddp, dist_sync_on_step): -# if target.max() > 1 and average == "micro": -# pytest.skip("average=micro and multiclass input cannot be used 
together") - -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=AveragePrecision, -# sk_metric=partial(sk_metric, num_classes=num_classes, average=average), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes, "average": average}, -# ) - -# @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) -# def test_average_precision_functional(self, preds, target, sk_metric, num_classes, average): -# if target.max() > 1 and average == "micro": -# pytest.skip("average=micro and multiclass input cannot be used together") - -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=average_precision, -# sk_metric=partial(sk_metric, num_classes=num_classes, average=average), -# metric_args={"num_classes": num_classes, "average": average}, -# ) - -# def test_average_precision_differentiability(self, preds, sk_metric, target, num_classes): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=AveragePrecision, -# metric_functional=average_precision, -# metric_args={"num_classes": num_classes}, -# ) - - -# @pytest.mark.parametrize( -# ["scores", "target", "expected_score"], -# [ -# # Check the average_precision_score of a constant predictor is -# # the TPR -# # Generate a dataset with 25% of positives -# # And a constant score -# # The precision is then the fraction of positive whatever the recall -# # is, as there is only one threshold: -# (tensor([1, 1, 1, 1]), tensor([0, 0, 0, 1]), 0.25), -# # With threshold 0.8 : 1 TP and 2 TN and one FN -# (tensor([0.6, 0.7, 0.8, 9]), tensor([1, 0, 0, 1]), 0.75), -# ], -# ) -# def test_average_precision(scores, target, expected_score): -# assert average_precision(scores, target) == expected_score - - -# def test_average_precision_warnings_and_errors(): -# """Test that the correct errors and warnings gets raised.""" - -# # check average argument -# with pytest.raises(ValueError, match="Expected argument `average` to be one .*"): -# AveragePrecision(num_classes=5, average="samples") - -# # check that micro average cannot be used with multilabel input -# pred = tensor( -# [ -# [0.75, 0.05, 0.05, 0.05, 0.05], -# [0.05, 0.75, 0.05, 0.05, 0.05], -# [0.05, 0.05, 0.75, 0.05, 0.05], -# [0.05, 0.05, 0.05, 0.75, 0.05], -# ] -# ) -# target = tensor([0, 1, 3, 2]) -# average_precision = AveragePrecision(num_classes=5, average="micro") -# with pytest.raises(ValueError, match="Cannot use `micro` average with multi-class input"): -# average_precision(pred, target) - -# # check that warning is thrown when average=macro and nan is encoutered in individual scores -# average_precision = AveragePrecision(num_classes=5, average="macro") -# with pytest.warns(UserWarning, match="Average precision score for one or more classes was `nan`.*"): -# average_precision(pred, target) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index e3dc891f43f..4060ec590a9 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -217,83 +217,3 @@ def test_multiclass_calibration_error_dtype_gpu(self, input, dtype): metric_args={"num_classes": NUM_CLASSES}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# @pytest.mark.parametrize("n_bins", [10, 15, 20]) -# @pytest.mark.parametrize("norm", ["l1", "l2", "max"]) -# @pytest.mark.parametrize( -# "preds, target", -# [ 
-# (_input_binary_prob.preds, _input_binary_prob.target), -# (_input_binary_logits.preds, _input_binary_logits.target), -# (_input_mcls_prob.preds, _input_mcls_prob.target), -# (_input_mcls_logits.preds, _input_mcls_logits.target), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target), -# ], -# ) -# class TestCE(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=CalibrationError, -# sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"n_bins": n_bins, "norm": norm}, -# ) - -# def test_ce_functional(self, preds, target, n_bins, norm): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=calibration_error, -# sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), -# metric_args={"n_bins": n_bins, "norm": norm}, -# ) - - -# @pytest.mark.parametrize("preds, targets", [(_input_mlb_prob.preds, _input_mlb_prob.target)]) -# def test_invalid_input(preds, targets): -# for p, t in zip(preds, targets): -# with pytest.raises( -# ValueError, -# match=re.escape( -# f"Calibration error is not well-defined for data with size {p.size()} and targets {t.size()}." -# ), -# ): -# calibration_error(p, t) - - -# @pytest.mark.parametrize( -# "preds, target", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target), -# (_input_mcls_prob.preds, _input_mcls_prob.target), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target), -# ], -# ) -# def test_invalid_norm(preds, target): -# with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. 
"): -# calibration_error(preds, target, norm="l3") - - -# @pytest.mark.parametrize("n_bins", [-10, -1, "fsd"]) -# @pytest.mark.parametrize( -# "preds, targets", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target), -# (_input_mcls_prob.preds, _input_mcls_prob.target), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target), -# ], -# ) -# def test_invalid_bins(preds, targets, n_bins): -# for p, t in zip(preds, targets): -# with pytest.raises( -# ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}" -# ): -# calibration_error(p, t, n_bins=n_bins) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 5b734a90816..39955efdeb0 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -213,118 +213,3 @@ def test_multiclass_confusion_matrix_dtypes_gpu(self, input, dtype): metric_args={"num_classes": NUM_CLASSES}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_cohen_kappa_binary_prob(preds, target, weights=None): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_binary(preds, target, weights=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multilabel_prob(preds, target, weights=None): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multilabel(preds, target, weights=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multiclass_prob(preds, target, weights=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multiclass(preds, target, weights=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multidim_multiclass_prob(preds, target, weights=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - - -# @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2), -# (_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2), -# (_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES), -# 
(_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES), -# (_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES), -# ], -# ) -# class TestCohenKappa(MetricTester): -# atol = 1e-5 - -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=CohenKappa, -# sk_metric=partial(sk_metric, weights=weights), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, -# ) - -# def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=cohen_kappa, -# sk_metric=partial(sk_metric, weights=weights), -# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, -# ) - -# def test_cohen_kappa_differentiability(self, preds, target, sk_metric, weights, num_classes): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=CohenKappa, -# metric_functional=cohen_kappa, -# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, -# ) - - -# def test_warning_on_wrong_weights(tmpdir): -# preds = torch.randint(3, size=(20,)) -# target = torch.randint(3, size=(20,)) - -# with pytest.raises(ValueError, match=".* ``weights`` but should be either None, 'linear' or 'quadratic'"): -# cohen_kappa(preds, target, num_classes=3, weights="unknown_arg") diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index e3a3ff7bcf3..2d58377e206 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -323,157 +323,3 @@ def test_warning_on_nan(): match=".* NaN values found in confusion matrix have been replaced with zeros.", ): multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") - - -# -------------------------- Old stuff -------------------------- - -# def _sk_cm_binary_prob(preds, target, normalize=None): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# def _sk_cm_binary(preds, target, normalize=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# def _sk_cm_multilabel_prob(preds, target, normalize=None): -# sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.numpy() - -# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) -# if normalize is not None: -# if normalize == "true": -# cm = cm / cm.sum(axis=1, keepdims=True) -# elif normalize == "pred": -# cm = cm / cm.sum(axis=0, keepdims=True) -# elif normalize == "all": -# cm = cm / cm.sum() -# cm[np.isnan(cm)] = 0 -# return cm - - -# def _sk_cm_multilabel(preds, target, normalize=None): -# sk_preds = preds.numpy() -# sk_target = target.numpy() - -# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) -# 
if normalize is not None: -# if normalize == "true": -# cm = cm / cm.sum(axis=1, keepdims=True) -# elif normalize == "pred": -# cm = cm / cm.sum(axis=0, keepdims=True) -# elif normalize == "all": -# cm = cm / cm.sum() -# cm[np.isnan(cm)] = 0 -# return cm - - -# def _sk_cm_multiclass_prob(preds, target, normalize=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# def _sk_cm_multiclass(preds, target, normalize=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# def _sk_cm_multidim_multiclass(preds, target, normalize=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes, multilabel", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), -# (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), -# (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), -# (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), -# (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), -# (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), -# (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), -# (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False), -# ], -# ) -# class TestConfusionMatrix(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_confusion_matrix( -# self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step -# ): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=ConfusionMatrix, -# sk_metric=partial(sk_metric, normalize=normalize), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# "normalize": normalize, -# "multilabel": multilabel, -# }, -# ) - -# def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=confusion_matrix, -# sk_metric=partial(sk_metric, normalize=normalize), -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# "normalize": normalize, -# "multilabel": multilabel, -# }, -# ) - -# def test_confusion_matrix_differentiability(self, 
normalize, preds, target, sk_metric, num_classes, multilabel): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=ConfusionMatrix, -# metric_functional=confusion_matrix, -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# "normalize": normalize, -# "multilabel": multilabel, -# }, -# ) - - -# def test_warning_on_nan(tmpdir): -# preds = torch.randint(3, size=(20,)) -# target = torch.randint(3, size=(20,)) - -# with pytest.warns( -# UserWarning, -# match=".* nan values found in confusion matrix have been replaced with zeros.", -# ): -# confusion_matrix(preds, target, num_classes=5, normalize="true") diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 72d8060757b..868b3f9613c 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -549,428 +549,3 @@ def test_multilabel_fbeta_score_half_gpu(self, input, module, functional, compar metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): -# if average == "none": -# average = None -# if num_classes == 1: -# average = "binary" - -# labels = list(range(num_classes)) -# try: -# labels.remove(ignore_index) -# except ValueError: -# pass - -# sk_preds, sk_target, _ = _input_format_classification( -# preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass -# ) -# sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() -# sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - -# if len(labels) != num_classes and not average: -# sk_scores = np.insert(sk_scores, ignore_index, np.nan) - -# return sk_scores - - -# def _sk_fbeta_f1_multidim_multiclass( -# preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average -# ): -# preds, target, _ = _input_format_classification( -# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass -# ) - -# if mdmc_average == "global": -# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) -# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - -# return _sk_fbeta_f1(preds, target, sk_fn, num_classes, average, False, ignore_index) -# if mdmc_average == "samplewise": -# scores = [] - -# for i in range(preds.shape[0]): -# pred_i = preds[i, ...].T -# target_i = target[i, ...].T -# scores_i = _sk_fbeta_f1(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) - -# scores.append(np.expand_dims(scores_i, 0)) - -# return np.concatenate(scores).mean(axis=0) - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), -# (F1Score, f1_score_pl), -# ], -# ) -# @pytest.mark.parametrize( -# "average, mdmc_average, num_classes, ignore_index, match_str", -# [ -# ("wrong", None, None, None, "`average`"), -# ("micro", "wrong", None, None, "`mdmc"), -# ("macro", None, None, None, "number of classes"), -# ("macro", None, 1, 0, "ignore_index"), -# ], -# ) -# def test_wrong_params(metric_class, metric_fn, average, mdmc_average, num_classes, ignore_index, match_str): -# with pytest.raises(ValueError, match=match_str): -# metric_class( -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - -# 
with pytest.raises(ValueError, match=match_str): -# metric_fn( -# _input_binary.preds[0], -# _input_binary.target[0], -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), -# (F1Score, f1_score_pl), -# ], -# ) -# def test_zero_division(metric_class, metric_fn): -# """Test that zero_division works correctly (currently should just set to 0).""" - -# preds = torch.tensor([1, 2, 1, 1]) -# target = torch.tensor([2, 0, 2, 1]) - -# cl_metric = metric_class(average="none", num_classes=3) -# cl_metric(preds, target) - -# result_cl = cl_metric.compute() -# result_fn = metric_fn(preds, target, average="none", num_classes=3) - -# assert result_cl[0] == result_fn[0] == 0 - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), -# (F1Score, f1_score_pl), -# ], -# ) -# def test_no_support(metric_class, metric_fn): -# """This tests a rare edge case, where there is only one class present. - -# in target, and ignore_index is set to exactly that class - and the -# average method is equal to 'weighted'. - -# This would mean that the sum of weights equals zero, and would, without -# taking care of this case, return NaN. However, the reduction function -# should catch that and set the metric to equal the value of zero_division -# in this case (zero_division is for now not configurable and equals 0). -# """ - -# preds = torch.tensor([1, 1, 0, 0]) -# target = torch.tensor([0, 0, 0, 0]) - -# cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) -# cl_metric(preds, target) - -# result_cl = cl_metric.compute() -# result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - -# assert result_cl == result_fn == 0 - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), -# (F1Score, f1_score_pl), -# ], -# ) -# @pytest.mark.parametrize( -# "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -# ) -# def test_class_not_present(metric_class, metric_fn, ignore_index, expected): -# """This tests that when metric is computed per class and a given class is not present in both the `preds` and -# `target`, the resulting score is `nan`.""" -# preds = torch.tensor([0, 0, 0]) -# target = torch.tensor([0, 0, 0]) -# num_classes = 2 - -# # test functional -# result_fn = metric_fn( -# preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index -# ) -# assert torch.allclose(expected, result_fn, equal_nan=True) - -# # test class -# cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) -# cl_metric(preds, target) -# result_cl = cl_metric.compute() -# assert torch.allclose(expected, result_cl, equal_nan=True) - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn, sk_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0), partial(fbeta_score, beta=2.0)), -# (F1Score, f1_score_pl, f1_score), -# ], -# ) -# @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -# @pytest.mark.parametrize("ignore_index", [None, 0]) -# @pytest.mark.parametrize( -# "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", -# [ -# 
(_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_fbeta_f1), -# (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_fbeta_f1), -# (_input_binary.preds, _input_binary.target, 1, False, None, _sk_fbeta_f1), -# (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1), -# (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1), -# (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_fbeta_f1), -# (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_fbeta_f1), -# (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_fbeta_f1), -# (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_fbeta_f1), -# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_fbeta_f1_multidim_multiclass), -# ( -# _input_mdmc_prob.preds, -# _input_mdmc_prob.target, -# NUM_CLASSES, -# None, -# "global", -# _sk_fbeta_f1_multidim_multiclass, -# ), -# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_fbeta_f1_multidim_multiclass), -# ( -# _input_mdmc_prob.preds, -# _input_mdmc_prob.target, -# NUM_CLASSES, -# None, -# "samplewise", -# _sk_fbeta_f1_multidim_multiclass, -# ), -# ], -# ) -# class TestFBeta(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_fbeta_f1( -# self, -# ddp: bool, -# dist_sync_on_step: bool, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=metric_class, -# sk_metric=partial( -# sk_wrapper, -# sk_fn=sk_fn, -# average=average, -# num_classes=num_classes, -# multiclass=multiclass, -# ignore_index=ignore_index, -# mdmc_average=mdmc_average, -# ), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - -# def test_fbeta_f1_functional( -# self, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample for 
'weighted' average") - -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=metric_fn, -# sk_metric=partial( -# sk_wrapper, -# sk_fn=sk_fn, -# average=average, -# num_classes=num_classes, -# multiclass=multiclass, -# ignore_index=ignore_index, -# mdmc_average=mdmc_average, -# ), -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - -# def test_fbeta_f1_differentiability( -# self, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - -# self.run_differentiability_test( -# preds, -# target, -# metric_functional=metric_fn, -# metric_module=metric_class, -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - - -# _mc_k_target = torch.tensor([0, 1, 2]) -# _mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -# _ml_k_target = torch.tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -# _ml_k_preds = torch.tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0)), -# (F1Score, fbeta_score_pl), -# ], -# ) -# @pytest.mark.parametrize( -# "k, preds, target, average, expected_fbeta, expected_f1", -# [ -# (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor(2 / 3), torch.tensor(2 / 3)), -# (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor(5 / 6), torch.tensor(2 / 3)), -# (1, _ml_k_preds, _ml_k_target, "micro", torch.tensor(0.0), torch.tensor(0.0)), -# (2, _ml_k_preds, _ml_k_target, "micro", torch.tensor(5 / 18), torch.tensor(2 / 9)), -# ], -# ) -# def test_top_k( -# metric_class, -# metric_fn, -# k: int, -# preds: Tensor, -# target: Tensor, -# average: str, -# expected_fbeta: Tensor, -# expected_f1: Tensor, -# ): -# """A simple test to check that top_k works as expected. - -# Just a sanity check, the tests in FBeta should already guarantee the corectness of results. 
-# """ -# class_metric = metric_class(top_k=k, average=average, num_classes=3) -# class_metric.update(preds, target) - -# if class_metric.beta != 1.0: -# result = expected_fbeta -# else: -# result = expected_f1 - -# assert torch.isclose(class_metric.compute(), result) -# assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) - - -# @pytest.mark.parametrize("ignore_index", [None, 2]) -# @pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) -# @pytest.mark.parametrize( -# "metric_class, metric_functional, sk_fn", -# [ -# (partial(FBetaScore, beta=2.0), partial(fbeta_score_pl, beta=2.0), partial(fbeta_score, beta=2.0)), -# (F1Score, f1_score_pl, f1_score), -# ], -# ) -# def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_index): -# preds = _input_miss_class.preds -# target = _input_miss_class.target -# preds_flat = torch.cat(list(preds), dim=0) -# target_flat = torch.cat(list(target), dim=0) - -# mc = metric_class(num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index) -# for i in range(NUM_BATCHES): -# mc.update(preds[i], target[i]) -# class_res = mc.compute() -# func_res = metric_functional( -# preds_flat, target_flat, num_classes=NUM_CLASSES, average=average, ignore_index=ignore_index -# ) -# sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=0) - -# assert torch.allclose(class_res, torch.tensor(sk_res).float()) -# assert torch.allclose(func_res, torch.tensor(sk_res).float()) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 6011ad5cbcc..e4c64704818 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -484,77 +484,3 @@ def test_multilabel_hamming_distance_dtype_gpu(self, input, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_hamming_loss(preds, target): -# sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) -# sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() -# sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - -# return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) - - -# @pytest.mark.parametrize( -# "preds, target", -# [ -# (_input_binary_logits.preds, _input_binary_logits.target), -# (_input_binary_prob.preds, _input_binary_prob.target), -# (_input_binary.preds, _input_binary.target), -# (_input_mlb_logits.preds, _input_mlb_logits.target), -# (_input_mlb_prob.preds, _input_mlb_prob.target), -# (_input_mlb.preds, _input_mlb.target), -# (_input_mcls_logits.preds, _input_mcls_logits.target), -# (_input_mcls_prob.preds, _input_mcls_prob.target), -# (_input_mcls.preds, _input_mcls.target), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target), -# (_input_mdmc.preds, _input_mdmc.target), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target), -# (_input_mlmd.preds, _input_mlmd.target), -# ], -# ) -# class TestHammingDistance(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [False, True]) -# def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=HammingDistance, -# sk_metric=_sk_hamming_loss, -# 
dist_sync_on_step=dist_sync_on_step, -# metric_args={"threshold": THRESHOLD}, -# ) - -# def test_hamming_distance_fn(self, preds, target): -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=hamming_distance, -# sk_metric=_sk_hamming_loss, -# metric_args={"threshold": THRESHOLD}, -# ) - -# def test_hamming_distance_differentiability(self, preds, target): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=HammingDistance, -# metric_functional=hamming_distance, -# metric_args={"threshold": THRESHOLD}, -# ) - - -# @pytest.mark.parametrize("threshold", [1.5]) -# def test_wrong_params(threshold): -# preds, target = _input_mcls_prob.preds, _input_mcls_prob.target - -# with pytest.raises(ValueError): -# ham_dist = HammingDistance(threshold=threshold) -# ham_dist(preds, target) -# ham_dist.compute() - -# with pytest.raises(ValueError): -# hamming_distance(preds, target, threshold=threshold) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index ec4f883f712..a0bf4904b55 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -208,134 +208,3 @@ def test_multiclass_hinge_loss_dtype_gpu(self, input, dtype): metric_args={"num_classes": NUM_CLASSES}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# _input_binary = Input( -# preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -# ) - -# _input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) - -# _input_multiclass = Input( -# preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -# target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -# ) - - -# def _sk_hinge(preds, target, squared, multiclass_mode): -# sk_preds, sk_target = preds.numpy(), target.numpy() - -# if multiclass_mode == MulticlassMode.ONE_VS_ALL: -# enc = OneHotEncoder() -# enc.fit(sk_target.reshape(-1, 1)) -# sk_target = enc.transform(sk_target.reshape(-1, 1)).toarray() - -# if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: -# sk_target = 2 * sk_target - 1 - -# if squared or sk_target.max() != 1 or sk_target.min() != -1: -# # Squared not an option in sklearn and infers classes incorrectly with single element, so adapted from source -# if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: -# margin = sk_target * sk_preds -# else: -# mask = np.ones_like(sk_preds, dtype=bool) -# mask[np.arange(sk_target.shape[0]), sk_target] = False -# margin = sk_preds[~mask] -# margin -= np.max(sk_preds[mask].reshape(sk_target.shape[0], -1), axis=1) -# measures = 1 - margin -# measures = np.clip(measures, 0, None) - -# if squared: -# measures = measures**2 -# return measures.mean(axis=0) -# if multiclass_mode == MulticlassMode.ONE_VS_ALL: -# result = np.zeros(sk_preds.shape[1]) -# for i in range(result.shape[0]): -# result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i]) -# return result - -# return sk_hinge(y_true=sk_target, pred_decision=sk_preds) - - -# @pytest.mark.parametrize( -# "preds, target, squared, multiclass_mode", -# [ -# (_input_binary.preds, _input_binary.target, False, None), -# (_input_binary.preds, _input_binary.target, True, None), -# (_input_binary_single.preds, _input_binary_single.target, False, None), -# (_input_binary_single.preds, _input_binary_single.target, True, 
None), -# (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.CRAMMER_SINGER), -# (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.CRAMMER_SINGER), -# (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.ONE_VS_ALL), -# (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.ONE_VS_ALL), -# ], -# ) -# class TestHinge(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=HingeLoss, -# sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={ -# "squared": squared, -# "multiclass_mode": multiclass_mode, -# }, -# ) - -# def test_hinge_fn(self, preds, target, squared, multiclass_mode): -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), -# sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), -# ) - -# def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=HingeLoss, -# metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), -# ) - - -# _input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) - -# _input_binary_different_sizes = Input( -# preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -# ) - -# _input_multi_different_sizes = Input( -# preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) -# ) - -# _input_extra_dim = Input( -# preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -# ) - - -# @pytest.mark.parametrize( -# "preds, target, multiclass_mode", -# [ -# (_input_multi_target.preds, _input_multi_target.target, None), -# (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), -# (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), -# (_input_extra_dim.preds, _input_extra_dim.target, None), -# (_input_multiclass.preds[0], _input_multiclass.target[0], "invalid_mode"), -# ], -# ) -# def test_bad_inputs_fn(preds, target, multiclass_mode): -# with pytest.raises(ValueError): -# _ = hinge_loss(preds, target, multiclass_mode=multiclass_mode) - - -# def test_bad_inputs_class(): -# with pytest.raises(ValueError): -# HingeLoss(multiclass_mode="invalid_mode") diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 8e35031b67e..4b6ad133211 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -315,213 +315,3 @@ def test_multilabel_jaccard_index_dtype_gpu(self, input, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_jaccard_binary_prob(preds, target, average=None): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, 
y_pred=sk_preds, average=average) - - -# def _sk_jaccard_binary(preds, target, average=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multilabel_prob(preds, target, average=None): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multilabel(preds, target, average=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multiclass_prob(preds, target, average=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multiclass(preds, target, average=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multidim_multiclass_prob(preds, target, average=None): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# def _sk_jaccard_multidim_multiclass(preds, target, average=None): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -# @pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"]) -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_jaccard_binary_prob, 2), -# (_input_binary.preds, _input_binary.target, _sk_jaccard_binary, 2), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_jaccard_multilabel_prob, 2), -# (_input_mlb.preds, _input_mlb.target, _sk_jaccard_multilabel, 2), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_jaccard_multiclass_prob, NUM_CLASSES), -# (_input_mcls.preds, _input_mcls.target, _sk_jaccard_multiclass, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_jaccard_multidim_multiclass_prob, NUM_CLASSES), -# (_input_mdmc.preds, _input_mdmc.target, _sk_jaccard_multidim_multiclass, NUM_CLASSES), -# ], -# ) -# class TestJaccardIndex(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): -# # average = "macro" if reduction == "elementwise_mean" else None # convert tags -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=JaccardIndex, -# sk_metric=partial(sk_metric, average=average), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, -# ) - -# def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes): -# # average = "macro" if reduction == "elementwise_mean" else None # convert tags -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=jaccard_index, -# sk_metric=partial(sk_metric, average=average), -# metric_args={"num_classes": 
num_classes, "threshold": THRESHOLD, "average": average}, -# ) - -# def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=JaccardIndex, -# metric_functional=jaccard_index, -# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, -# ) - - -# @pytest.mark.parametrize( -# ["half_ones", "average", "ignore_index", "expected"], -# [ -# (False, "none", None, Tensor([1, 1, 1])), -# (False, "macro", None, Tensor([1])), -# (False, "none", 0, Tensor([1, 1])), -# (True, "none", None, Tensor([0.5, 0.5, 0.5])), -# (True, "macro", None, Tensor([0.5])), -# (True, "none", 0, Tensor([2 / 3, 1 / 2])), -# ], -# ) -# def test_jaccard(half_ones, average, ignore_index, expected): -# preds = (torch.arange(120) % 3).view(-1, 1) -# target = (torch.arange(120) % 3).view(-1, 1) -# if half_ones: -# preds[:60] = 1 -# jaccard_val = jaccard_index( -# preds=preds, -# target=target, -# average=average, -# num_classes=3, -# ignore_index=ignore_index, -# # reduction=reduction, -# ) -# assert torch.allclose(jaccard_val, expected, atol=1e-9) - - -# # test `absent_score` -# @pytest.mark.parametrize( -# ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], -# [ -# # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid -# # scores the function can return ([0., 1.] range, inclusive). -# # 2 classes, class 0 is correct everywhere, class 1 is absent. -# ([0], [0], None, -1.0, 2, [1.0, -1.0]), -# ([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), -# # absent_score not applied if only class 0 is present and it's the only class. -# ([0], [0], None, -1.0, 1, [1.0]), -# # 2 classes, class 1 is correct everywhere, class 0 is absent. -# ([1], [1], None, -1.0, 2, [-1.0, 1.0]), -# ([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), -# # When 0 index ignored, class 0 does not get a score (not even the absent_score). -# ([1], [1], 0, -1.0, 2, [1.0]), -# # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. -# ([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), -# ([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), -# # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. -# ([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), -# ([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), -# # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class -# # 2 is absent. -# ([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), -# # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class -# # 2 is absent. -# ([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), -# # Sanity checks with absent_score of 1.0. 
-# ([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), -# ([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), -# ], -# ) -# def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): -# jaccard_val = jaccard_index( -# preds=tensor(pred), -# target=tensor(target), -# average=None, -# ignore_index=ignore_index, -# absent_score=absent_score, -# num_classes=num_classes, -# # reduction="none", -# ) -# assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -# # example data taken from -# # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -# @pytest.mark.parametrize( -# ["pred", "target", "ignore_index", "num_classes", "average", "expected"], -# [ -# # Ignoring an index outside of [0, num_classes-1] should have no effect. -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), -# # Ignoring a valid index drops only that index from the result. -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]), -# # When reducing to mean or sum, the ignored index does not contribute to the output. -# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]), -# # ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), -# ], -# ) -# def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected): -# jaccard_val = jaccard_index( -# preds=tensor(pred), -# target=tensor(target), -# average=average, -# ignore_index=ignore_index, -# num_classes=num_classes, -# # reduction=reduction, -# ) -# assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 351e13d459f..93b3371c0f4 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -301,125 +301,3 @@ def test_zero_case_in_multiclass(): # Example where neither 1 or 2 is present in the target tensor out = multiclass_matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) assert out == 0.0 - - -# -------------------------- Old stuff -------------------------- - -# def _sk_matthews_corrcoef_binary_prob(preds, target): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_binary(preds, target): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multilabel_prob(preds, target): -# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multilabel(preds, target): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multiclass_prob(preds, target): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return 
sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multiclass(preds, target): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target): -# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# def _sk_matthews_corrcoef_multidim_multiclass(preds, target): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2), -# (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2), -# (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES), -# (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES), -# ( -# _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES -# ), -# (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES), -# ], -# ) -# class TestMatthewsCorrCoef(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=MatthewsCorrCoef, -# sk_metric=sk_metric, -# dist_sync_on_step=dist_sync_on_step, -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# }, -# ) - -# def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=matthews_corrcoef, -# sk_metric=sk_metric, -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# }, -# ) - -# def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=MatthewsCorrCoef, -# metric_functional=matthews_corrcoef, -# metric_args={ -# "num_classes": num_classes, -# "threshold": THRESHOLD, -# }, -# ) - - -# def test_zero_case(): -# """Cases where the denominator in the matthews corrcoef is 0, the score should return 0.""" -# # Example where neither 1 or 2 is present in the target tensor -# out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) -# assert out == 0.0 diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index f17daf2f4c1..75a2873c7fd 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -540,439 +540,3 @@ def test_multilabel_precision_recall_half_gpu(self, input, module, functional, c metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff 
-------------------------- - -# def _sk_prec_recall(preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average=None): -# # todo: `mdmc_average` is unused -# if average == "none": -# average = None -# if num_classes == 1: -# average = "binary" - -# labels = list(range(num_classes)) -# try: -# labels.remove(ignore_index) -# except ValueError: -# pass - -# sk_preds, sk_target, _ = _input_format_classification( -# preds, target, THRESHOLD, num_classes=num_classes, multiclass=multiclass -# ) -# sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - -# sk_scores = sk_fn(sk_target, sk_preds, average=average, zero_division=0, labels=labels) - -# if len(labels) != num_classes and not average: -# sk_scores = np.insert(sk_scores, ignore_index, np.nan) - -# return sk_scores - - -# def _sk_prec_recall_multidim_multiclass( -# preds, target, sk_fn, num_classes, average, multiclass, ignore_index, mdmc_average -# ): -# preds, target, _ = _input_format_classification( -# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass -# ) - -# if mdmc_average == "global": -# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) -# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - -# return _sk_prec_recall(preds, target, sk_fn, num_classes, average, False, ignore_index) -# if mdmc_average == "samplewise": -# scores = [] - -# for i in range(preds.shape[0]): -# pred_i = preds[i, ...].T -# target_i = target[i, ...].T -# scores_i = _sk_prec_recall(pred_i, target_i, sk_fn, num_classes, average, False, ignore_index) - -# scores.append(np.expand_dims(scores_i, 0)) - -# return np.concatenate(scores).mean(axis=0) - - -# @pytest.mark.parametrize("metric, fn_metric", [(Precision, precision), (Recall, recall)]) -# @pytest.mark.parametrize( -# "average, mdmc_average, num_classes, ignore_index, match_str", -# [ -# ("wrong", None, None, None, "`average`"), -# ("micro", "wrong", None, None, "`mdmc"), -# ("macro", None, None, None, "number of classes"), -# ("macro", None, 1, 0, "ignore_index"), -# ], -# ) -# def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): -# with pytest.raises(ValueError, match=match_str): -# metric( -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - -# with pytest.raises(ValueError, match=match_str): -# fn_metric( -# _input_binary.preds[0], -# _input_binary.target[0], -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - -# with pytest.raises(ValueError, match=match_str): -# precision_recall( -# _input_binary.preds[0], -# _input_binary.target[0], -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -# def test_zero_division(metric_class, metric_fn): -# """Test that zero_division works correctly (currently should just set to 0).""" - -# preds = tensor([0, 2, 1, 1]) -# target = tensor([2, 1, 2, 1]) - -# cl_metric = metric_class(average="none", num_classes=3) -# cl_metric(preds, target) - -# result_cl = cl_metric.compute() -# result_fn = metric_fn(preds, target, average="none", num_classes=3) - -# assert result_cl[0] == result_fn[0] == 0 - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -# def test_no_support(metric_class, metric_fn): -# 
"""This tests a rare edge case, where there is only one class present. - -# in target, and ignore_index is set to exactly that class - and the -# average method is equal to 'weighted'. - -# This would mean that the sum of weights equals zero, and would, without -# taking care of this case, return NaN. However, the reduction function -# should catch that and set the metric to equal the value of zero_division -# in this case (zero_division is for now not configurable and equals 0). -# """ - -# preds = tensor([1, 1, 0, 0]) -# target = tensor([0, 0, 0, 0]) - -# cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=0) -# cl_metric(preds, target) - -# result_cl = cl_metric.compute() -# result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=0) - -# assert result_cl == result_fn == 0 - - -# @pytest.mark.parametrize( -# "metric_class, metric_fn, sk_fn", [(Recall, recall, recall_score), (Precision, precision, precision_score)] -# ) -# @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -# @pytest.mark.parametrize("ignore_index", [None, 0]) -# @pytest.mark.parametrize( -# "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", -# [ -# (_input_binary_logits.preds, _input_binary_logits.target, 1, None, None, _sk_prec_recall), -# (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_prec_recall), -# (_input_binary.preds, _input_binary.target, 1, False, None, _sk_prec_recall), -# (_input_mlb_logits.preds, _input_mlb_logits.target, NUM_CLASSES, None, None, _sk_prec_recall), -# (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), -# (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_prec_recall), -# (_input_mcls_logits.preds, _input_mcls_logits.target, NUM_CLASSES, None, None, _sk_prec_recall), -# (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_prec_recall), -# (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_prec_recall), -# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_prec_recall_multidim_multiclass), -# ( -# _input_mdmc_prob.preds, -# _input_mdmc_prob.target, -# NUM_CLASSES, -# None, -# "global", -# _sk_prec_recall_multidim_multiclass, -# ), -# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_prec_recall_multidim_multiclass), -# ( -# _input_mdmc_prob.preds, -# _input_mdmc_prob.target, -# NUM_CLASSES, -# None, -# "samplewise", -# _sk_prec_recall_multidim_multiclass, -# ), -# ], -# ) -# class TestPrecisionRecall(MetricTester): -# @pytest.mark.parametrize("ddp", [False, True]) -# @pytest.mark.parametrize("dist_sync_on_step", [False]) -# def test_precision_recall_class( -# self, -# ddp: bool, -# dist_sync_on_step: bool, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# # todo: `metric_fn` is unused -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample 
for 'weighted' average") - -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=metric_class, -# sk_metric=partial( -# sk_wrapper, -# sk_fn=sk_fn, -# average=average, -# num_classes=num_classes, -# multiclass=multiclass, -# ignore_index=ignore_index, -# mdmc_average=mdmc_average, -# ), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - -# def test_precision_recall_fn( -# self, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# # todo: `metric_class` is unused -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=metric_fn, -# sk_metric=partial( -# sk_wrapper, -# sk_fn=sk_fn, -# average=average, -# num_classes=num_classes, -# multiclass=multiclass, -# ignore_index=ignore_index, -# mdmc_average=mdmc_average, -# ), -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - -# def test_precision_recall_differentiability( -# self, -# preds: Tensor, -# target: Tensor, -# sk_wrapper: Callable, -# metric_class: Metric, -# metric_fn: Callable, -# sk_fn: Callable, -# multiclass: Optional[bool], -# num_classes: Optional[int], -# average: str, -# mdmc_average: Optional[str], -# ignore_index: Optional[int], -# ): -# # todo: `metric_class` is unused -# if num_classes == 1 and average != "micro": -# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") - -# if ignore_index is not None and preds.ndim == 2: -# pytest.skip("Skipping ignore_index test with binary inputs.") - -# if average == "weighted" and ignore_index is not None and mdmc_average is not None: -# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") - -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=metric_class, -# metric_functional=metric_fn, -# metric_args={ -# "num_classes": num_classes, -# "average": average, -# "threshold": THRESHOLD, -# "multiclass": multiclass, -# "ignore_index": ignore_index, -# "mdmc_average": mdmc_average, -# }, -# ) - - -# @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) -# def test_precision_recall_joint(average): -# """A simple test of the joint precision_recall metric. - -# No need to test this thorougly, as it is just a combination of precision and recall, which are already tested -# thoroughly. 
-# """ - -# precision_result = precision( -# _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES -# ) -# recall_result = recall( -# _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES -# ) - -# prec_recall_result = precision_recall( -# _input_mcls_prob.preds[0], _input_mcls_prob.target[0], average=average, num_classes=NUM_CLASSES -# ) - -# assert torch.equal(precision_result, prec_recall_result[0]) -# assert torch.equal(recall_result, prec_recall_result[1]) - - -# _mc_k_target = tensor([0, 1, 2]) -# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Recall, recall), (Precision, precision)]) -# @pytest.mark.parametrize( -# "k, preds, target, average, expected_prec, expected_recall", -# [ -# (1, _mc_k_preds, _mc_k_target, "micro", tensor(2 / 3), tensor(2 / 3)), -# (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2), tensor(1.0)), -# (1, _ml_k_preds, _ml_k_target, "micro", tensor(0.0), tensor(0.0)), -# (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6), tensor(1 / 3)), -# ], -# ) -# def test_top_k( -# metric_class, -# metric_fn, -# k: int, -# preds: Tensor, -# target: Tensor, -# average: str, -# expected_prec: Tensor, -# expected_recall: Tensor, -# ): -# """A simple test to check that top_k works as expected. - -# Just a sanity check, the tests in StatScores should already guarantee the correctness of results. -# """ - -# class_metric = metric_class(top_k=k, average=average, num_classes=3) -# class_metric.update(preds, target) - -# if metric_class.__name__ == "Precision": -# result = expected_prec -# else: -# result = expected_recall - -# assert torch.equal(class_metric.compute(), result) -# assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Precision, precision), (Recall, recall)]) -# @pytest.mark.parametrize( -# "ignore_index, expected", [(None, torch.tensor([1.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -# ) -# def test_class_not_present(metric_class, metric_fn, ignore_index, expected): -# """This tests that when metric is computed per class and a given class is not present in both the `preds` and -# `target`, the resulting score is `nan`.""" -# preds = torch.tensor([0, 0, 0]) -# target = torch.tensor([0, 0, 0]) -# num_classes = 2 - -# # test functional -# result_fn = metric_fn( -# preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index -# ) -# assert torch.allclose(expected, result_fn, equal_nan=True) - -# # test class -# cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) -# cl_metric(preds, target) -# result_cl = cl_metric.compute() -# assert torch.allclose(expected, result_cl, equal_nan=True) - - -# @pytest.mark.parametrize("average", ["micro", "macro", "weighted"]) -# @pytest.mark.parametrize( -# "metric_class, metric_functional, sk_fn", -# [ -# (Precision, precision, precision_score), -# (Recall, recall, recall_score) -# ] -# ) -# def test_same_input(metric_class, metric_functional, sk_fn, average): -# preds = _input_miss_class.preds -# target = _input_miss_class.target -# preds_flat = torch.cat(list(preds), dim=0) -# target_flat = torch.cat(list(target), dim=0) - -# mc = 
metric_class(num_classes=NUM_CLASSES, average=average) -# for i in range(NUM_BATCHES): -# mc.update(preds[i], target[i]) -# class_res = mc.compute() -# func_res = metric_functional(preds_flat, target_flat, num_classes=NUM_CLASSES, average=average) -# sk_res = sk_fn(target_flat, preds_flat, average=average, zero_division=1) - -# assert torch.allclose(class_res, torch.tensor(sk_res).float()) -# assert torch.allclose(func_res, torch.tensor(sk_res).float()) - - -# @pytest.mark.parametrize("metric_cls", [Precision, Recall]) -# def test_noneavg(metric_cls, noneavg=_negmetric_noneavg): -# prec = MetricWrapper(metric_cls(average="none", num_classes=noneavg["pred1"].shape[1])) -# result1 = prec(noneavg["pred1"], noneavg["target1"]) -# assert torch.allclose(noneavg["res1"], result1, equal_nan=True) -# result2 = prec(noneavg["pred2"], noneavg["target2"]) -# assert torch.allclose(noneavg["res2"], result2, equal_nan=True) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index a9d5330aaf8..9bebfc9c40b 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -351,121 +351,3 @@ def test_multilabel_precision_recall_curve_threshold_arg(self, input, threshold_ assert torch.allclose(p1[i], p2[i]) assert torch.allclose(r1[i], r2[i]) assert torch.allclose(t1[i], t2) - - -# -------------------------- Old stuff -------------------------- - - -# def _sk_precision_recall_curve(y_true, probas_pred, num_classes=1): -# """Adjusted comparison function that can also handles multiclass.""" -# if num_classes == 1: -# return sk_precision_recall_curve(y_true, probas_pred) - -# precision, recall, thresholds = [], [], [] -# for i in range(num_classes): -# y_true_temp = np.zeros_like(y_true) -# y_true_temp[y_true == i] = 1 -# res = sk_precision_recall_curve(y_true_temp, probas_pred[:, i]) -# precision.append(res[0]) -# recall.append(res[1]) -# thresholds.append(res[2]) -# return precision, recall, thresholds - - -# def _sk_prec_rc_binary_prob(preds, target, num_classes=1): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -# def _sk_prec_rc_multiclass_prob(preds, target, num_classes=1): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -# def _sk_prec_rc_multidim_multiclass_prob(preds, target, num_classes=1): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.view(-1).numpy() -# return _sk_precision_recall_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_prec_rc_binary_prob, 1), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_prec_rc_multiclass_prob, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_prec_rc_multidim_multiclass_prob, NUM_CLASSES), -# ], -# ) -# class TestPrecisionRecallCurve(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_precision_recall_curve(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): -# 
self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=PrecisionRecallCurve, -# sk_metric=partial(sk_metric, num_classes=num_classes), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes}, -# ) - -# def test_precision_recall_curve_functional(self, preds, target, sk_metric, num_classes): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=precision_recall_curve, -# sk_metric=partial(sk_metric, num_classes=num_classes), -# metric_args={"num_classes": num_classes}, -# ) - -# def test_precision_recall_curve_differentiability(self, preds, target, sk_metric, num_classes): -# self.run_differentiability_test( -# preds, -# target, -# metric_module=PrecisionRecallCurve, -# metric_functional=precision_recall_curve, -# metric_args={"num_classes": num_classes}, -# ) - - -# @pytest.mark.parametrize( -# ["pred", "target", "expected_p", "expected_r", "expected_t"], -# [([1, 2, 3, 4], [1, 0, 0, 1], [0.5, 1 / 3, 0.5, 1.0, 1.0], [1, 0.5, 0.5, 0.5, 0.0], [1, 2, 3, 4])], -# ) -# def test_pr_curve(pred, target, expected_p, expected_r, expected_t): -# p, r, t = precision_recall_curve(tensor(pred), tensor(target)) -# assert p.size() == r.size() -# assert p.size(0) == t.size(0) + 1 - -# assert torch.allclose(p, tensor(expected_p).to(p)) -# assert torch.allclose(r, tensor(expected_r).to(r)) -# assert torch.allclose(t, tensor(expected_t).to(t)) - - -# @pytest.mark.parametrize( -# "sample_weight, pos_label, exp_shape", -# [(1, 1.0, 42), (None, 1.0, 42)], -# ) -# def test_binary_clf_curve(sample_weight, pos_label, exp_shape): -# # TODO: move back the pred and target to test func arguments -# # if you fix the array inside the function, you'd also have fix the shape, -# # because when the array changes, you also have to fix the shape -# seed_all(0) -# pred = torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100 -# target = tensor([0, 1] * 50, dtype=torch.int) -# if sample_weight is not None: -# sample_weight = torch.ones_like(pred) * sample_weight - -# fps, tps, thresh = _binary_clf_curve(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label) - -# assert isinstance(tps, Tensor) -# assert isinstance(fps, Tensor) -# assert isinstance(thresh, Tensor) -# assert tps.shape == (exp_shape,) -# assert fps.shape == (exp_shape,) -# assert thresh.shape == (exp_shape,) diff --git a/tests/unittests/classification/test_ranking.py b/tests/unittests/classification/test_ranking.py index dad5ba39448..280c0a26865 100644 --- a/tests/unittests/classification/test_ranking.py +++ b/tests/unittests/classification/test_ranking.py @@ -146,75 +146,3 @@ def test_multilabel_ranking_dtype_gpu(self, input, metric, functional_metric, sk metric_args={"num_labels": NUM_CLASSES}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_coverage_error(preds, target, sample_weight=None): -# if sample_weight is not None: -# sample_weight = sample_weight.numpy() -# return sk_coverage_error(target.numpy(), preds.numpy(), sample_weight=sample_weight) - - -# def _sk_label_ranking(preds, target, sample_weight=None): -# if sample_weight is not None: -# sample_weight = sample_weight.numpy() -# return sk_label_ranking(target.numpy(), preds.numpy(), sample_weight=sample_weight) - - -# def _sk_label_ranking_loss(preds, target, sample_weight=None): -# if sample_weight is not None: -# sample_weight = sample_weight.numpy() -# return sk_label_ranking_loss(target.numpy(), preds.numpy(), 
sample_weight=sample_weight) - - -# @pytest.mark.parametrize( -# "metric, functional_metric, sk_metric", -# [ -# (CoverageError, coverage_error, _sk_coverage_error), -# (LabelRankingAveragePrecision, label_ranking_average_precision, _sk_label_ranking), -# (LabelRankingLoss, label_ranking_loss, _sk_label_ranking_loss), -# ], -# ) -# @pytest.mark.parametrize( -# "preds, target", -# [ -# (_input_mlb_logits.preds, _input_mlb_logits.target), -# (_input_mlb_prob.preds, _input_mlb_prob.target), -# ], -# ) -# @pytest.mark.parametrize("sample_weight", [None, torch.rand(NUM_BATCHES, BATCH_SIZE)]) -# class TestRanking(MetricTester): -# @pytest.mark.parametrize("ddp", [False, True]) -# @pytest.mark.parametrize("dist_sync_on_step", [False, True]) -# def test_ranking_class( -# self, ddp, dist_sync_on_step, preds, target, metric, functional_metric, sk_metric, sample_weight -# ): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=metric, -# sk_metric=sk_metric, -# dist_sync_on_step=dist_sync_on_step, -# fragment_kwargs=True, -# sample_weight=sample_weight, -# ) - -# def test_ranking_functional(self, preds, target, metric, functional_metric, sk_metric, sample_weight): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=functional_metric, -# sk_metric=sk_metric, -# fragment_kwargs=True, -# sample_weight=sample_weight, -# ) - -# def test_ranking_differentiability(self, preds, target, metric, functional_metric, sk_metric, sample_weight): -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=metric, -# metric_functional=functional_metric, -# ) diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 30e66bd042c..c3f21c26089 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -340,140 +340,3 @@ def test_multilabel_roc_threshold_arg(self, input, threshold_fn): assert torch.allclose(p1[i], p2[i]) assert torch.allclose(r1[i], r2[i]) assert torch.allclose(t1[i], t2) - - -# -------------------------- Old stuff -------------------------- - -# def _sk_roc_curve(y_true, probas_pred, num_classes: int = 1, multilabel: bool = False): -# """Adjusted comparison function that can also handles multiclass.""" -# if num_classes == 1: -# return sk_roc_curve(y_true, probas_pred, drop_intermediate=False) - -# fpr, tpr, thresholds = [], [], [] -# for i in range(num_classes): -# if multilabel: -# y_true_temp = y_true[:, i] -# else: -# y_true_temp = np.zeros_like(y_true) -# y_true_temp[y_true == i] = 1 - -# res = sk_roc_curve(y_true_temp, probas_pred[:, i], drop_intermediate=False) -# fpr.append(res[0]) -# tpr.append(res[1]) -# thresholds.append(res[2]) -# return fpr, tpr, thresholds - - -# def _sk_roc_binary_prob(preds, target, num_classes=1): -# sk_preds = preds.view(-1).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -# def _sk_roc_multiclass_prob(preds, target, num_classes=1): -# sk_preds = preds.reshape(-1, num_classes).numpy() -# sk_target = target.view(-1).numpy() - -# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes) - - -# def _sk_roc_multidim_multiclass_prob(preds, target, num_classes=1): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.view(-1).numpy() -# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, 
num_classes=num_classes) - - -# def _sk_roc_multilabel_prob(preds, target, num_classes=1): -# sk_preds = preds.numpy() -# sk_target = target.numpy() -# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) - - -# def _sk_roc_multilabel_multidim_prob(preds, target, num_classes=1): -# sk_preds = preds.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# sk_target = target.transpose(0, 1).reshape(num_classes, -1).transpose(0, 1).numpy() -# return _sk_roc_curve(y_true=sk_target, probas_pred=sk_preds, num_classes=num_classes, multilabel=True) - - -# @pytest.mark.parametrize( -# "preds, target, sk_metric, num_classes", -# [ -# (_input_binary_prob.preds, _input_binary_prob.target, _sk_roc_binary_prob, 1), -# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_roc_multiclass_prob, NUM_CLASSES), -# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_roc_multidim_multiclass_prob, NUM_CLASSES), -# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_roc_multilabel_prob, NUM_CLASSES), -# (_input_mlmd_prob.preds, _input_mlmd_prob.target, _sk_roc_multilabel_multidim_prob, NUM_CLASSES), -# ], -# ) -# class TestROC(MetricTester): -# @pytest.mark.parametrize("ddp", [True, False]) -# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) -# def test_roc(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=ROC, -# sk_metric=partial(sk_metric, num_classes=num_classes), -# dist_sync_on_step=dist_sync_on_step, -# metric_args={"num_classes": num_classes}, -# ) - -# def test_roc_functional(self, preds, target, sk_metric, num_classes): -# self.run_functional_metric_test( -# preds, -# target, -# metric_functional=roc, -# sk_metric=partial(sk_metric, num_classes=num_classes), -# metric_args={"num_classes": num_classes}, -# ) - -# def test_roc_differentiability(self, preds, target, sk_metric, num_classes): -# self.run_differentiability_test( -# preds, -# target, -# metric_module=ROC, -# metric_functional=roc, -# metric_args={"num_classes": num_classes}, -# ) - - -# @pytest.mark.parametrize( -# ["pred", "target", "expected_tpr", "expected_fpr"], -# [ -# ([0, 1], [0, 1], [0, 1, 1], [0, 0, 1]), -# ([1, 0], [0, 1], [0, 0, 1], [0, 1, 1]), -# ([1, 1], [1, 0], [0, 1], [0, 1]), -# ([1, 0], [1, 0], [0, 1, 1], [0, 0, 1]), -# ([0.5, 0.5], [0, 1], [0, 1], [0, 1]), -# ], -# ) -# def test_roc_curve(pred, target, expected_tpr, expected_fpr): -# fpr, tpr, thresh = roc(tensor(pred), tensor(target)) - -# assert fpr.shape == tpr.shape -# assert fpr.size(0) == thresh.size(0) -# assert torch.allclose(fpr, tensor(expected_fpr).to(fpr)) -# assert torch.allclose(tpr, tensor(expected_tpr).to(tpr)) - - -# def test_warnings_on_missing_class(): -# """Test that a warning is given if either the positive or negative class is missing.""" -# metric = ROC() -# # no positive samples -# warning = ( -# "No positive samples in targets, true positive value should be meaningless." -# " Returning zero tensor in true positive score" -# ) -# with pytest.warns(UserWarning, match=warning): -# _, tpr, _ = metric(torch.randn(10).sigmoid(), torch.zeros(10)) -# assert all(tpr == 0) - -# warning = ( -# "No negative samples in targets, false positive value should be meaningless." 
-# " Returning zero tensor in false positive score" -# ) -# with pytest.warns(UserWarning, match=warning): -# fpr, _, _ = metric(torch.randn(10).sigmoid(), torch.ones(10)) -# assert all(fpr == 0) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index ca50d88ed49..0681fc98a78 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -526,381 +526,3 @@ def test_multilabel_specificity_dtype_gpu(self, input, dtype): metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, dtype=dtype, ) - - -# -------------------------- Old stuff -------------------------- - - -# def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k): -# preds, target, _ = _input_format_classification( -# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k -# ) -# sk_preds, sk_target = preds.numpy(), target.numpy() - -# if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: -# sk_preds = np.delete(sk_preds, ignore_index, 1) -# sk_target = np.delete(sk_target, ignore_index, 1) - -# if preds.shape[1] == 1 and reduce == "samples": -# sk_target = sk_target.T -# sk_preds = sk_preds.T - -# sk_stats = multilabel_confusion_matrix( -# sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 -# ) - -# if preds.shape[1] == 1 and reduce != "samples": -# sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] -# else: -# sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] - -# if reduce == "micro": -# sk_stats = sk_stats.sum(axis=0, keepdims=True) - -# sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) - -# if reduce == "micro": -# sk_stats = sk_stats[0] - -# if reduce == "macro" and ignore_index is not None and preds.shape[1]: -# sk_stats[ignore_index, :] = -1 - -# if reduce == "micro": -# _, fp, tn, _, _ = sk_stats -# else: -# _, fp, tn, _ = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] -# return fp, tn - - -# def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): - -# if stats: -# fp, tn = stats -# else: -# stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k) -# fp, tn = stats - -# fp, tn = tensor(fp), tensor(tn) -# spec = _reduce_specificity( -# numerator=tn, -# denominator=tn + fp, -# weights=None if reduce != "weighted" else tn + fp, -# average=reduce, -# mdmc_average=mdmc_reduce, -# ) -# if reduce in [None, "none"] and ignore_index is not None and preds.shape[1] > 1: -# spec = spec.numpy() -# spec = np.insert(spec, ignore_index, math.nan) -# spec = tensor(spec) - -# return spec - - -# def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k=None): -# preds, target, _ = _input_format_classification( -# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k -# ) - -# if mdmc_reduce == "global": -# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) -# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) -# return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) -# fp, tn = [], [] -# stats = [] - -# for i in range(preds.shape[0]): -# pred_i = preds[i, ...].T -# target_i = target[i, ...].T -# fp_i, tn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k) 
-# fp.append(fp_i) -# tn.append(tn_i) - -# stats.append(fp) -# stats.append(tn) -# return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) - - -# @pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)]) -# @pytest.mark.parametrize( -# "average, mdmc_average, num_classes, ignore_index, match_str", -# [ -# ("wrong", None, None, None, "`average`"), -# ("micro", "wrong", None, None, "`mdmc"), -# ("macro", None, None, None, "number of classes"), -# ("macro", None, 1, 0, "ignore_index"), -# ], -# ) -# def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): -# with pytest.raises(ValueError, match=match_str): -# metric( -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - -# with pytest.raises(ValueError, match=match_str): -# fn_metric( -# _input_binary.preds[0], -# _input_binary.target[0], -# average=average, -# mdmc_average=mdmc_average, -# num_classes=num_classes, -# ignore_index=ignore_index, -# ) - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -# def test_zero_division(metric_class, metric_fn): -# """Test that zero_division works correctly (currently should just set to 0).""" - -# preds = tensor([1, 2, 1, 1]) -# target = tensor([0, 0, 0, 0]) - -# cl_metric = metric_class(average="none", num_classes=3) -# cl_metric(preds, target) - -# result_cl = cl_metric.compute() -# result_fn = metric_fn(preds, target, average="none", num_classes=3) - -# assert result_cl[0] == result_fn[0] == 0 - - -# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -# def test_no_support(metric_class, metric_fn): -# """This tests a rare edge case, where there is only one class present. - -# in target, and ignore_index is set to exactly that class - and the -# average method is equal to 'weighted'. - -# This would mean that the sum of weights equals zero, and would, without -# taking care of this case, return NaN. However, the reduction function -# should catch that and set the metric to equal the value of zero_division -# in this case (zero_division is for now not configurable and equals 0). 
-#     """
-
-#     preds = tensor([1, 1, 0, 0])
-#     target = tensor([0, 0, 0, 0])
-
-#     cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=1)
-#     cl_metric(preds, target)
-
-#     result_cl = cl_metric.compute()
-#     result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=1)
-
-#     assert result_cl == result_fn == 0
-
-
-# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
-# @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
-# @pytest.mark.parametrize("ignore_index", [None, 0])
-# @pytest.mark.parametrize(
-#     "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
-#     [
-#         (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_spec),
-#         (_input_binary.preds, _input_binary.target, 1, False, None, _sk_spec),
-#         (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_spec),
-#         (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_spec),
-#         (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec),
-#         (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec),
-#         (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls),
-#         (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls),
-#         (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls),
-#         (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls),
-#     ],
-# )
-# class TestSpecificity(MetricTester):
-#     @pytest.mark.parametrize("ddp", [False, True])
-#     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-#     def test_specificity_class(
-#         self,
-#         ddp: bool,
-#         dist_sync_on_step: bool,
-#         preds: Tensor,
-#         target: Tensor,
-#         sk_wrapper: Callable,
-#         metric_class: Metric,
-#         metric_fn: Callable,
-#         multiclass: Optional[bool],
-#         num_classes: Optional[int],
-#         average: str,
-#         mdmc_average: Optional[str],
-#         ignore_index: Optional[int],
-#     ):
-#         # todo: `metric_fn` is unused
-#         if num_classes == 1 and average != "micro":
-#             pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-#             pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
-
-#         self.run_class_metric_test(
-#             ddp=ddp,
-#             preds=preds,
-#             target=target,
-#             metric_class=metric_class,
-#             sk_metric=partial(
-#                 sk_wrapper,
-#                 reduce=average,
-#                 num_classes=num_classes,
-#                 multiclass=multiclass,
-#                 ignore_index=ignore_index,
-#                 mdmc_reduce=mdmc_average,
-#             ),
-#             dist_sync_on_step=dist_sync_on_step,
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "average": average,
-#                 "threshold": THRESHOLD,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "mdmc_average": mdmc_average,
-#             },
-#         )
-
-#     def test_specificity_fn(
-#         self,
-#         preds: Tensor,
-#         target: Tensor,
-#         sk_wrapper: Callable,
-#         metric_class: Metric,
-#         metric_fn: Callable,
-#         multiclass: Optional[bool],
-#         num_classes: Optional[int],
-#         average: str,
-#         mdmc_average: Optional[str],
-#         ignore_index: Optional[int],
-#     ):
-#         # todo: `metric_class` is unused
-#         if num_classes == 1 and average != "micro":
-#             pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-#             pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
-
-#         self.run_functional_metric_test(
-#             preds,
-#             target,
-#             metric_functional=metric_fn,
-#             sk_metric=partial(
-#                 sk_wrapper,
-#                 reduce=average,
-#                 num_classes=num_classes,
-#                 multiclass=multiclass,
-#                 ignore_index=ignore_index,
-#                 mdmc_reduce=mdmc_average,
-#             ),
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "average": average,
-#                 "threshold": THRESHOLD,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "mdmc_average": mdmc_average,
-#             },
-#         )
-
-#     def test_accuracy_differentiability(
-#         self,
-#         preds: Tensor,
-#         target: Tensor,
-#         sk_wrapper: Callable,
-#         metric_class: Metric,
-#         metric_fn: Callable,
-#         multiclass: Optional[bool],
-#         num_classes: Optional[int],
-#         average: str,
-#         mdmc_average: Optional[str],
-#         ignore_index: Optional[int],
-#     ):
-
-#         if num_classes == 1 and average != "micro":
-#             pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-#             pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
-
-#         self.run_differentiability_test(
-#             preds=preds,
-#             target=target,
-#             metric_module=metric_class,
-#             metric_functional=metric_fn,
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "average": average,
-#                 "threshold": THRESHOLD,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "mdmc_average": mdmc_average,
-#             },
-#         )
-
-
-# _mc_k_target = tensor([0, 1, 2])
-# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
-# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
-# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
-
-
-# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
-# @pytest.mark.parametrize(
-#     "k, preds, target, average, expected_spec",
-#     [
-#         (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
-#         (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
-#         (1, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 2)),
-#         (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6)),
-#     ],
-# )
-# def test_top_k(
-#     metric_class,
-#     metric_fn,
-#     k: int,
-#     preds: Tensor,
-#     target: Tensor,
-#     average: str,
-#     expected_spec: Tensor,
-# ):
-#     """A simple test to check that top_k works as expected.
-
-#     Just a sanity check, the tests in Specificity should already guarantee the correctness of results.
-#     """
-
-#     class_metric = metric_class(top_k=k, average=average, num_classes=3)
-#     class_metric.update(preds, target)
-
-#     assert torch.equal(class_metric.compute(), expected_spec)
-#     assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), expected_spec)
-
-
-# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
-# @pytest.mark.parametrize(
-#     "ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))]
-# )
-# def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
-#     """This tests that when metric is computed per class and a given class is not present in both the `preds` and
-#     `target`, the resulting score is `nan`."""
-#     preds = torch.tensor([0, 0, 0])
-#     target = torch.tensor([0, 0, 0])
-#     num_classes = 2
-
-#     # test functional
-#     result_fn = metric_fn(
-#         preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index
-#     )
-#     assert torch.allclose(expected, result_fn, equal_nan=True)
-
-#     # test class
-#     cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index)
-#     cl_metric(preds, target)
-#     result_cl = cl_metric.compute()
-#     assert torch.allclose(expected, result_cl, equal_nan=True)
diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py
index 4b1cf2046e6..4331020adf5 100644
--- a/tests/unittests/classification/test_stat_scores.py
+++ b/tests/unittests/classification/test_stat_scores.py
@@ -479,305 +479,3 @@ def test_multilabel_stat_scores_dtype_gpu(self, input, dtype):
             metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
             dtype=dtype,
         )
-
-
-# -------------------------- Old stuff --------------------------
-
-# def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None):
-#     # todo: `mdmc_reduce` is unused
-#     preds, target, _ = _input_format_classification(
-#         preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
-#     )
-#     sk_preds, sk_target = preds.numpy(), target.numpy()
-
-#     if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1:
-#         sk_preds = np.delete(sk_preds, ignore_index, 1)
-#         sk_target = np.delete(sk_target, ignore_index, 1)
-
-#     if preds.shape[1] == 1 and reduce == "samples":
-#         sk_target = sk_target.T
-#         sk_preds = sk_preds.T
-
-#     sk_stats = multilabel_confusion_matrix(
-#         sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1
-#     )
-
-#     if preds.shape[1] == 1 and reduce != "samples":
-#         sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]]
-#     else:
-#         sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]]
-
-#     if reduce == "micro":
-#         sk_stats = sk_stats.sum(axis=0, keepdims=True)
-
-#     sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1)
-
-#     if reduce == "micro":
-#         sk_stats = sk_stats[0]
-
-#     if reduce == "macro" and ignore_index is not None and preds.shape[1]:
-#         sk_stats[ignore_index, :] = -1
-
-#     return sk_stats
-
-
-# def _sk_stat_scores_mdim_mcls(
-#     preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold
-# ):
-#     preds, target, _ = _input_format_classification(
-#         preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k
-#     )
-
-#     if mdmc_reduce == "global":
-#         preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1])
-#         target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1])
-
-#         return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold)
-#     if mdmc_reduce == "samplewise":
-#         scores = []
-
-#         for i in range(preds.shape[0]):
-#             pred_i = preds[i, ...].T
-#             target_i = target[i, ...].T
-#             scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold)
-
-#             scores.append(np.expand_dims(scores_i, 0))
-
-#         return np.concatenate(scores)
-
-
-# @pytest.mark.parametrize(
-#     "reduce, mdmc_reduce, num_classes, inputs, ignore_index",
-#     [
-#         ["unknown", None, None, _input_binary, None],
-#         ["micro", "unknown", None, _input_binary, None],
-#         ["macro", None, None, _input_binary, None],
-#         ["micro", None, None, _input_mdmc_prob, None],
-#         ["micro", None, None, _input_binary_prob, 0],
-#         ["micro", None, None, _input_mcls_prob, NUM_CLASSES],
-#         ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES],
-#     ],
-# )
-# def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index):
-#     """Test a combination of parameters that are invalid and should raise an error.
-
-#     This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when
-#     ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index``
-#     when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes.
-#     """
-#     with pytest.raises(ValueError):
-#         stat_scores(
-#             inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index
-#         )
-
-#     with pytest.raises(ValueError):
-#         sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index)
-#         sts(inputs.preds[0], inputs.target[0])
-
-
-# @pytest.mark.parametrize("ignore_index", [None, 0])
-# @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"])
-# @pytest.mark.parametrize(
-#     "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold",
-#     [
-#         (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0),
-#         (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5),
-#         (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5),
-#         (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
-#         (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
-#         (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5),
-#         (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5),
-#         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5),
-#         (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
-#         (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0),
-#         (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0),
-#         (
-#             _input_mdmc.preds,
-#             _input_mdmc.target,
-#             _sk_stat_scores_mdim_mcls,
-#             "samplewise",
-#             NUM_CLASSES,
-#             None,
-#             None,
-#             0.0
-#         ),
-#         (
-#             _input_mdmc_prob.preds,
-#             _input_mdmc_prob.target,
-#             _sk_stat_scores_mdim_mcls,
-#             "samplewise",
-#             NUM_CLASSES,
-#             None,
-#             None,
-#             0.0,
-#         ),
-#         (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0),
-#         (
-#             _input_mdmc_prob.preds,
-#             _input_mdmc_prob.target,
-#             _sk_stat_scores_mdim_mcls,
-#             "global",
-#             NUM_CLASSES,
-#             None,
-#             None,
-#             0.0,
-#         ),
-#     ],
-# )
-# class TestStatScores(MetricTester):
-#     # DDP tests temporarily disabled due to hanging issues
-#     @pytest.mark.parametrize("ddp", [False])
-#     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-#     @pytest.mark.parametrize("dtype", [torch.float, torch.double])
-#     def test_stat_scores_class(
-#         self,
-#         ddp: bool,
-#         dist_sync_on_step: bool,
-#         dtype: torch.dtype,
-#         sk_fn: Callable,
-#         preds: Tensor,
-#         target: Tensor,
-#         reduce: str,
-#         mdmc_reduce: Optional[str],
-#         num_classes: Optional[int],
-#         multiclass: Optional[bool],
-#         ignore_index: Optional[int],
-#         top_k: Optional[int],
-#         threshold: Optional[float],
-#     ):
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         if preds.is_floating_point():
-#             preds = preds.to(dtype)
-#         if target.is_floating_point():
-#             target = target.to(dtype)
-
-#         self.run_class_metric_test(
-#             ddp=ddp,
-#             preds=preds,
-#             target=target,
-#             metric_class=StatScores,
-#             sk_metric=partial(
-#                 sk_fn,
-#                 reduce=reduce,
-#                 mdmc_reduce=mdmc_reduce,
-#                 num_classes=num_classes,
-#                 multiclass=multiclass,
-#                 ignore_index=ignore_index,
-#                 top_k=top_k,
-#                 threshold=threshold,
-#             ),
-#             dist_sync_on_step=dist_sync_on_step,
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "reduce": reduce,
-#                 "mdmc_reduce": mdmc_reduce,
-#                 "threshold": threshold,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "top_k": top_k,
-#             },
-#         )
-
-#     def test_stat_scores_fn(
-#         self,
-#         sk_fn: Callable,
-#         preds: Tensor,
-#         target: Tensor,
-#         reduce: str,
-#         mdmc_reduce: Optional[str],
-#         num_classes: Optional[int],
-#         multiclass: Optional[bool],
-#         ignore_index: Optional[int],
-#         top_k: Optional[int],
-#         threshold: Optional[float],
-#     ):
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         self.run_functional_metric_test(
-#             preds,
-#             target,
-#             metric_functional=stat_scores,
-#             sk_metric=partial(
-#                 sk_fn,
-#                 reduce=reduce,
-#                 mdmc_reduce=mdmc_reduce,
-#                 num_classes=num_classes,
-#                 multiclass=multiclass,
-#                 ignore_index=ignore_index,
-#                 top_k=top_k,
-#                 threshold=threshold,
-#             ),
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "reduce": reduce,
-#                 "mdmc_reduce": mdmc_reduce,
-#                 "threshold": threshold,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "top_k": top_k,
-#             },
-#         )
-
-#     def test_stat_scores_differentiability(
-#         self,
-#         sk_fn: Callable,
-#         preds: Tensor,
-#         target: Tensor,
-#         reduce: str,
-#         mdmc_reduce: Optional[str],
-#         num_classes: Optional[int],
-#         multiclass: Optional[bool],
-#         ignore_index: Optional[int],
-#         top_k: Optional[int],
-#         threshold: Optional[float],
-#     ):
-#         if ignore_index is not None and preds.ndim == 2:
-#             pytest.skip("Skipping ignore_index test with binary inputs.")
-
-#         self.run_differentiability_test(
-#             preds,
-#             target,
-#             metric_module=StatScores,
-#             metric_functional=stat_scores,
-#             metric_args={
-#                 "num_classes": num_classes,
-#                 "reduce": reduce,
-#                 "mdmc_reduce": mdmc_reduce,
-#                 "threshold": threshold,
-#                 "multiclass": multiclass,
-#                 "ignore_index": ignore_index,
-#                 "top_k": top_k,
-#             },
-#         )
-
-
-# _mc_k_target = tensor([0, 1, 2])
-# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
-# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
-# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
-
-
-# @pytest.mark.parametrize(
-#     "k, preds, target, reduce, expected",
-#     [
-#         (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])),
-#         (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])),
-#         (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])),
-#         (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])),
-#         (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])),
-#         (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])),
-#         (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])),
-#         (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])),
-#     ],
-# )
-# def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor):
-#     """A simple test to check that top_k works as expected."""
-
-#     class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3)
-#     class_metric.update(preds, target)
-
-#     assert torch.equal(class_metric.compute(), expected.T)
-#     assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T)