Skip to content

Commit

Permalink
[Refactor] Classification 3/n (#1145)
Browse files Browse the repository at this point in the history
* fbeta init structure

* working fbeta binary, multilabel

* precision and recall functional

* refactor testing

* working fbeta in all cases

* working precision recall with tests

* formatting

* docs

* init files

* fix stat score docstring

* docstrings

* flake8

* add link

* fix integer division

* Apply suggestions from code review

* Apply suggestions from code review

* docs

* fix docs?

* naming mistake

* remove duplicate

* fix docs

* try fixing tests

* docs

* fix tests

* docs

* docs

* try again

* try again

* again

* again

* please fix

* please solve

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
4 people committed Sep 12, 2022
1 parent f499910 commit 2179495
Show file tree
Hide file tree
Showing 24 changed files with 4,661 additions and 833 deletions.
50 changes: 46 additions & 4 deletions docs/source/classification/f1_score.rst
Original file line number Diff line number Diff line change
@@ -1,20 +1,62 @@
.. customcarditem::
:header: F1 Score
:header: F-1 Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

########
F1 Score
########
#########
F-1 Score
#########

Module Interface
________________

F1Score
^^^^^^^

.. autoclass:: torchmetrics.F1Score
:noindex:

BinaryF1Score
^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryF1Score
:noindex:

MulticlassF1Score
^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassF1Score
:noindex:

MultilabelF1Score
^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelF1Score
:noindex:

Functional Interface
____________________

f1_score
^^^^^^^^

.. autofunction:: torchmetrics.functional.f1_score
:noindex:

binary_f1_score
^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_f1_score
:noindex:

multiclass_f1_score
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_f1_score
:noindex:

multilabel_f1_score
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_f1_score
:noindex:
50 changes: 46 additions & 4 deletions docs/source/classification/fbeta_score.rst
Original file line number Diff line number Diff line change
@@ -1,22 +1,64 @@
.. customcarditem::
:header: FBeta Score
:header: F-Beta Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

###########
FBeta Score
###########
############
F-Beta Score
############

Module Interface
________________

FBetaScore
^^^^^^^^^^

.. autoclass:: torchmetrics.FBetaScore
:noindex:

BinaryFBetaScore
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryFBetaScore
:noindex:

MulticlassFBetaScore
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassFBetaScore
:noindex:

MultilabelFBetaScore
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelFBetaScore
:noindex:

Functional Interface
____________________

fbeta_score
^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.fbeta_score
:noindex:

binary_fbeta_score
^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_fbeta_score
:noindex:

multiclass_fbeta_score
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_fbeta_score
:noindex:

multilabel_fbeta_score
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_fbeta_score
:noindex:
38 changes: 38 additions & 0 deletions docs/source/classification/precision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

#########
Precision
#########
Expand All @@ -13,8 +15,44 @@ ________________
.. autoclass:: torchmetrics.Precision
:noindex:

BinaryPrecision
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryPrecision
:noindex:

MulticlassPrecision
^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassPrecision
:noindex:

MultilabelPrecision
^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelPrecision
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.precision
:noindex:

binary_precision
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_precision
:noindex:

multiclass_precision
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_precision
:noindex:

multilabel_precision
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_precision
:noindex:
36 changes: 36 additions & 0 deletions docs/source/classification/recall.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.Recall
:noindex:

BinaryRecall
^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryRecall
:noindex:

MulticlassRecall
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassRecall
:noindex:

MultilabelRecall
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelRecall
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.recall
:noindex:

binary_recall
^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_recall
:noindex:

multiclass_recall
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_recall
:noindex:

multilabel_recall
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_recall
:noindex:
24 changes: 24 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@
AveragePrecision,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryF1Score,
BinaryFBetaScore,
BinaryJaccardIndex,
BinaryMatthewsCorrCoef,
BinaryPrecision,
BinaryRecall,
BinaryStatScores,
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
Expand All @@ -50,12 +54,20 @@
MatthewsCorrCoef,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassF1Score,
MulticlassFBetaScore,
MulticlassJaccardIndex,
MulticlassMatthewsCorrCoef,
MulticlassPrecision,
MulticlassRecall,
MulticlassStatScores,
MultilabelConfusionMatrix,
MultilabelF1Score,
MultilabelFBetaScore,
MultilabelJaccardIndex,
MultilabelMatthewsCorrCoef,
MultilabelPrecision,
MultilabelRecall,
MultilabelStatScores,
Precision,
PrecisionRecallCurve,
Expand Down Expand Up @@ -153,7 +165,13 @@
"ExplainedVariance",
"ExtendedEditDistance",
"F1Score",
"BinaryF1Score",
"MulticlassF1Score",
"MultilabelF1Score",
"FBetaScore",
"BinaryFBetaScore",
"MulticlassFBetaScore",
"MultilabelFBetaScore",
"HammingDistance",
"HingeLoss",
"JaccardIndex",
Expand Down Expand Up @@ -185,10 +203,16 @@
"PermutationInvariantTraining",
"Perplexity",
"Precision",
"BinaryPrecision",
"MulticlassPrecision",
"MultilabelPrecision",
"PrecisionRecallCurve",
"PeakSignalNoiseRatio",
"R2Score",
"Recall",
"BinaryRecall",
"MulticlassRecall",
"MultilabelRecall",
"RetrievalFallOut",
"RetrievalHitRate",
"RetrievalMAP",
Expand Down
22 changes: 20 additions & 2 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.f_beta import ( # noqa: F401
BinaryF1Score,
BinaryFBetaScore,
F1Score,
FBetaScore,
MulticlassF1Score,
MulticlassFBetaScore,
MultilabelF1Score,
MultilabelFBetaScore,
)
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
from torchmetrics.classification.jaccard import ( # noqa: F401
Expand All @@ -50,7 +59,16 @@
MulticlassMatthewsCorrCoef,
MultilabelMatthewsCorrCoef,
)
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall import ( # noqa: F401
BinaryPrecision,
BinaryRecall,
MulticlassPrecision,
MulticlassRecall,
MultilabelPrecision,
MultilabelRecall,
Precision,
Recall,
)
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.ranking import ( # noqa: F401
CoverageError,
Expand Down
Loading

0 comments on commit 2179495

Please sign in to comment.