Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Classification 6/n #1163

Merged
merged 48 commits into from
Aug 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
02f5dd1
some idea
SkafteNicki Jul 21, 2022
353627d
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Jul 22, 2022
fae3ae7
implementation stuff
SkafteNicki Jul 24, 2022
00368bb
error message
SkafteNicki Jul 25, 2022
82fcd9b
somethings working
SkafteNicki Jul 25, 2022
4ba7218
working
SkafteNicki Jul 25, 2022
dc29a83
working multilabel precision recall
SkafteNicki Jul 26, 2022
cb35003
docstrings
SkafteNicki Jul 26, 2022
67abff7
working average precision
SkafteNicki Jul 26, 2022
42a328c
working roc
SkafteNicki Jul 26, 2022
1754964
working auroc
SkafteNicki Jul 27, 2022
1c29cb8
working precision at recall
SkafteNicki Jul 27, 2022
481c03d
docs
SkafteNicki Jul 27, 2022
bfdf0c8
init files
SkafteNicki Jul 27, 2022
bdb0746
beginning doctest
SkafteNicki Jul 28, 2022
518a33b
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Jul 28, 2022
614bbb1
more docs
SkafteNicki Jul 28, 2022
f2a4f80
more docs
SkafteNicki Jul 28, 2022
278f9ca
more docs
SkafteNicki Jul 28, 2022
51560df
more docs
SkafteNicki Jul 28, 2022
9c95fa0
more docs
SkafteNicki Jul 28, 2022
1653a97
more docs
SkafteNicki Jul 28, 2022
c4eac51
correction to math
SkafteNicki Jul 28, 2022
f160e95
fix mypy
SkafteNicki Jul 28, 2022
0b5d4ba
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 1, 2022
c112b65
add suggestions
SkafteNicki Aug 1, 2022
0c8df91
change default from 100 to None
SkafteNicki Aug 1, 2022
11bfdff
fix
SkafteNicki Aug 3, 2022
267e5c8
fix
SkafteNicki Aug 3, 2022
abef85e
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 3, 2022
110020b
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 5, 2022
dfa05ea
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 6, 2022
c260037
some fixes
SkafteNicki Aug 6, 2022
8898f7b
suggestions for stancld
SkafteNicki Aug 15, 2022
7f66689
try fixing
SkafteNicki Aug 15, 2022
b0ad98a
try fix
SkafteNicki Aug 15, 2022
a672a37
try again
SkafteNicki Aug 15, 2022
b7779de
more fixing
SkafteNicki Aug 15, 2022
c366f4c
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 15, 2022
fa7c2f9
another fix
SkafteNicki Aug 15, 2022
88ea62c
skip half + cpu test for old versions
SkafteNicki Aug 15, 2022
fc5c2f0
fix link
SkafteNicki Aug 15, 2022
e356630
Merge branch 'devel/classification' into refactor/classification_6
SkafteNicki Aug 17, 2022
c383a54
another fix
SkafteNicki Aug 17, 2022
8f16941
nan safety
SkafteNicki Aug 17, 2022
ed8223e
another fix
SkafteNicki Aug 17, 2022
961cf9f
skip non working cpu half tests
SkafteNicki Aug 18, 2022
ab1d4ed
skip non working cpu + half tests
SkafteNicki Aug 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/source/classification/auroc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,44 @@ ________________
.. autoclass:: torchmetrics.AUROC
:noindex:

BinaryAUROC
^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryAUROC
:noindex:

MulticlassAUROC
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassAUROC
:noindex:

MultilabelAUROC
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelAUROC
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.auroc
:noindex:

binary_auroc
^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_auroc
:noindex:

multiclass_auroc
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_auroc
:noindex:

multilabel_auroc
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_auroc
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/average_precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.AveragePrecision
:noindex:

BinaryAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryAveragePrecision
:noindex:

MulticlassAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassAveragePrecision
:noindex:

MultilabelAveragePrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelAveragePrecision
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.average_precision
:noindex:

binary_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_average_precision
:noindex:

multiclass_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_average_precision
:noindex:

multilabel_average_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_average_precision
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/precision_recall_curve.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.PrecisionRecallCurve
:noindex:

BinaryPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryPrecisionRecallCurve
:noindex:

MulticlassPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassPrecisionRecallCurve
:noindex:

MultilabelPrecisionRecallCurve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelPrecisionRecallCurve
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.precision_recall_curve
:noindex:

binary_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_precision_recall_curve
:noindex:

multiclass_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_precision_recall_curve
:noindex:

multilabel_precision_recall_curve
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_precision_recall_curve
:noindex:
50 changes: 50 additions & 0 deletions docs/source/classification/recall_at_fixed_precision.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
.. customcarditem::
:header: Recall At Fixed Precision
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

#########################
Recall At Fixed Precision
#########################

Module Interface
________________

BinaryRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryRecallAtFixedPrecision
:noindex:

MulticlassRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassRecallAtFixedPrecision
:noindex:

MultilabelRecallAtFixedPrecision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelRecallAtFixedPrecision
:noindex:

Functional Interface
____________________

binary_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_recall_at_fixed_precision
:noindex:

multiclass_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_recall_at_fixed_precision
:noindex:

multilabel_recall_at_fixed_precision
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_recall_at_fixed_precision
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/roc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.ROC
:noindex:

BinaryROC
^^^^^^^^^

.. autoclass:: torchmetrics.BinaryROC
:noindex:

MulticlassROC
^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassROC
:noindex:

MultilabelROC
^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelROC
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.roc
:noindex:

binary_roc
^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_roc
:noindex:

multiclass_roc
^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_roc
:noindex:

multilabel_roc
^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_roc
:noindex:
30 changes: 30 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
Accuracy,
AveragePrecision,
BinaryAccuracy,
BinaryAUROC,
BinaryAveragePrecision,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryF1Score,
Expand All @@ -35,7 +37,10 @@
BinaryJaccardIndex,
BinaryMatthewsCorrCoef,
BinaryPrecision,
BinaryPrecisionRecallCurve,
BinaryRecall,
BinaryRecallAtFixedPrecision,
BinaryROC,
BinarySpecificity,
BinaryStatScores,
BinnedAveragePrecision,
Expand All @@ -56,6 +61,8 @@
LabelRankingLoss,
MatthewsCorrCoef,
MulticlassAccuracy,
MulticlassAUROC,
MulticlassAveragePrecision,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassF1Score,
Expand All @@ -64,10 +71,15 @@
MulticlassJaccardIndex,
MulticlassMatthewsCorrCoef,
MulticlassPrecision,
MulticlassPrecisionRecallCurve,
MulticlassRecall,
MulticlassRecallAtFixedPrecision,
MulticlassROC,
MulticlassSpecificity,
MulticlassStatScores,
MultilabelAccuracy,
MultilabelAUROC,
MultilabelAveragePrecision,
MultilabelConfusionMatrix,
MultilabelCoverageError,
MultilabelExactMatch,
Expand All @@ -77,9 +89,12 @@
MultilabelJaccardIndex,
MultilabelMatthewsCorrCoef,
MultilabelPrecision,
MultilabelPrecisionRecallCurve,
MultilabelRankingAveragePrecision,
MultilabelRankingLoss,
MultilabelRecall,
MultilabelRecallAtFixedPrecision,
MultilabelROC,
MultilabelSpecificity,
MultilabelStatScores,
Precision,
Expand Down Expand Up @@ -155,6 +170,21 @@
"MultilabelAccuracy",
"AUC",
"AUROC",
"BinaryAUROC",
"BinaryAveragePrecision",
"BinaryPrecisionRecallCurve",
"BinaryRecallAtFixedPrecision",
"BinaryROC",
"MultilabelROC",
"MulticlassAUROC",
"MulticlassAveragePrecision",
"MulticlassPrecisionRecallCurve",
"MulticlassRecallAtFixedPrecision",
"MulticlassROC",
"MultilabelAUROC",
"MultilabelAveragePrecision",
"MultilabelPrecisionRecallCurve",
"MultilabelRecallAtFixedPrecision",
"AveragePrecision",
"BinnedAveragePrecision",
"BinnedPrecisionRecallCurve",
Expand Down
23 changes: 19 additions & 4 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from torchmetrics.classification.precision_recall_curve import ( # noqa: F401 isort:skip
PrecisionRecallCurve,
BinaryPrecisionRecallCurve,
MulticlassPrecisionRecallCurve,
MultilabelPrecisionRecallCurve,
)
from torchmetrics.classification.stat_scores import ( # noqa: F401 isort:skip
BinaryStatScores,
MulticlassStatScores,
Expand All @@ -31,8 +37,13 @@
MultilabelAccuracy,
)
from torchmetrics.classification.auc import AUC # noqa: F401
from torchmetrics.classification.auroc import AUROC # noqa: F401
from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC # noqa: F401
from torchmetrics.classification.average_precision import ( # noqa: F401
AveragePrecision,
BinaryAveragePrecision,
MulticlassAveragePrecision,
MultilabelAveragePrecision,
)
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
Expand Down Expand Up @@ -80,7 +91,6 @@
Precision,
Recall,
)
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.ranking import ( # noqa: F401
CoverageError,
LabelRankingAveragePrecision,
Expand All @@ -89,7 +99,12 @@
MultilabelRankingAveragePrecision,
MultilabelRankingLoss,
)
from torchmetrics.classification.roc import ROC # noqa: F401
from torchmetrics.classification.recall_at_fixed_precision import ( # noqa: F401
BinaryRecallAtFixedPrecision,
MulticlassRecallAtFixedPrecision,
MultilabelRecallAtFixedPrecision,
)
from torchmetrics.classification.roc import ROC, BinaryROC, MulticlassROC, MultilabelROC # noqa: F401
from torchmetrics.classification.specificity import ( # noqa: F401
BinarySpecificity,
MulticlassSpecificity,
Expand Down
Loading