Skip to content

Commit

Permalink
[Refactor] Classification 5/n (#1159)
Browse files Browse the repository at this point in the history
* base structure

* partly working accuracy

* init files

* working accuracy

* exact match

* mix mistake

* exact matching

* init files

* fix doctest

* mypy

* fix docs

* fix integer division

* try fixing test dependency

* fix f-string

* fix integer division

* fix integer division

* fix?

* try again

* fix pep8

* docs

* try something

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 26, 2022
1 parent b2707ce commit 2901f7a
Show file tree
Hide file tree
Showing 17 changed files with 2,048 additions and 401 deletions.
36 changes: 36 additions & 0 deletions docs/source/classification/accuracy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,44 @@ ________________
.. autoclass:: torchmetrics.Accuracy
:noindex:

BinaryAccuracy
^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryAccuracy
:noindex:

MulticlassAccuracy
^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassAccuracy
:noindex:

MultilabelAccuracy
^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelAccuracy
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.accuracy
:noindex:

binary_accuracy
^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_accuracy
:noindex:

multiclass_accuracy
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_accuracy
:noindex:

multilabel_accuracy
^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_accuracy
:noindex:
26 changes: 26 additions & 0 deletions docs/source/classification/exact_match.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
.. customcarditem::
:header: Exact Match
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

###########
Exact Match
###########

Module Interface
________________

MultilabelExactMatch
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelExactMatch
:noindex:

Functional Interface
____________________

multilabel_exact_match
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_exact_match
:noindex:
8 changes: 8 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ROC,
Accuracy,
AveragePrecision,
BinaryAccuracy,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryF1Score,
Expand Down Expand Up @@ -54,6 +55,7 @@
LabelRankingAveragePrecision,
LabelRankingLoss,
MatthewsCorrCoef,
MulticlassAccuracy,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassF1Score,
Expand All @@ -65,7 +67,9 @@
MulticlassRecall,
MulticlassSpecificity,
MulticlassStatScores,
MultilabelAccuracy,
MultilabelConfusionMatrix,
MultilabelExactMatch,
MultilabelF1Score,
MultilabelFBetaScore,
MultilabelHammingDistance,
Expand Down Expand Up @@ -143,6 +147,9 @@
__all__ = [
"functional",
"Accuracy",
"BinaryAccuracy",
"MulticlassAccuracy",
"MultilabelAccuracy",
"AUC",
"AUROC",
"AveragePrecision",
Expand Down Expand Up @@ -187,6 +194,7 @@
"BinaryJaccardIndex",
"MulticlassJaccardIndex",
"MultilabelJaccardIndex",
"MultilabelExactMatch",
"KLDivergence",
"LabelRankingAveragePrecision",
"LabelRankingLoss",
Expand Down
8 changes: 7 additions & 1 deletion src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
StatScores,
)

from torchmetrics.classification.accuracy import Accuracy # noqa: F401
from torchmetrics.classification.accuracy import ( # noqa: F401
Accuracy,
BinaryAccuracy,
MulticlassAccuracy,
MultilabelAccuracy,
)
from torchmetrics.classification.auc import AUC # noqa: F401
from torchmetrics.classification.auroc import AUROC # noqa: F401
from torchmetrics.classification.avg_precision import AveragePrecision # noqa: F401
Expand All @@ -34,6 +39,7 @@
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.exact_match import MultilabelExactMatch # noqa: F401
from torchmetrics.classification.f_beta import ( # noqa: F401
BinaryF1Score,
BinaryFBetaScore,
Expand Down
Loading

0 comments on commit 2901f7a

Please sign in to comment.