Skip to content

Commit

Permalink
[Refactor] Classification 2/n (#1143)
Browse files Browse the repository at this point in the history
* base structure

* bincount

* binary

* files

* stat score

* multiclass + multilabel confmat

* update

* stat_score

* change bincount

* move back

* del tests

* rest of structure

* confmat working

* working binary stat scores

* full testing

* update

* update

* add missing tests

* update

* multilabel stat scores

* disable old testing

* more testing

* flaky tests

* changelog

* refactor

* class interface

* fixes

* typing

* update

* fix tests

* Update src/torchmetrics/functional/classification/confusion_matrix.py

* missing literal

* add docstring to functional confusion matrix

* add docstring to modular confusion matrix

* add docstring to functional stat scores

* add docstring to modular stat scores

* make private

* docs

* fix mypy and doctests

* fix docs formatting

* literal backwards

* custom movedim

* docs

* inits

* fix

* fix mistakes

* working cohen kappa

* Delete speed.py

* cohen kappa done

* working matthews

* working jaccard

* docstrings for jaccard

* small improve

* typing

* fix doctest

* try something

* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
5 people committed Aug 26, 2022
1 parent f06e41e commit ed34f7f
Show file tree
Hide file tree
Showing 22 changed files with 2,696 additions and 410 deletions.
32 changes: 32 additions & 0 deletions docs/source/classification/cohen_kappa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,43 @@ Cohen Kappa
Module Interface
________________

CohenKappa
^^^^^^^^^^

.. autoclass:: torchmetrics.CohenKappa
:noindex:

BinaryCohenKappa
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryCohenKappa
:noindex:
:exclude-members: update, compute

MulticlassCohenKappa
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassCohenKappa
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

cohen_kappa
^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.cohen_kappa
:noindex:

binary_cohen_kappa
^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_cohen_kappa
:noindex:

multiclass_cohen_kappa
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_cohen_kappa
:noindex:
46 changes: 46 additions & 0 deletions docs/source/classification/jaccard_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,57 @@ Jaccard Index
Module Interface
________________

CohenKappa
^^^^^^^^^^

.. autoclass:: torchmetrics.JaccardIndex
:noindex:

BinaryJaccardIndex
^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryJaccardIndex
:noindex:
:exclude-members: update, compute

MulticlassJaccardIndex
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassJaccardIndex
:noindex:
:exclude-members: update, compute

MultilabelJaccardIndex
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelJaccardIndex
:noindex:
:exclude-members: update, compute


Functional Interface
____________________

jaccard_index
^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.jaccard_index
:noindex:

binary_jaccard_index
^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_jaccard_index
:noindex:

multiclass_jaccard_index
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_jaccard_index
:noindex:

multilabel_jaccard_index
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_jaccard_index
:noindex:
52 changes: 49 additions & 3 deletions docs/source/classification/matthews_corr_coef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,64 @@

.. include:: ../links.rst

####################
Matthews Corr. Coef.
####################
################################
Matthews Correlation Coefficient
################################

Module Interface
________________

MatthewsCorrCoef
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MatthewsCorrCoef
:noindex:

BinaryMatthewsCorrCoef
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryMatthewsCorrCoef
:noindex:
:exclude-members: update, compute

MulticlassMatthewsCorrCoef
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassMatthewsCorrCoef
:noindex:
:exclude-members: update, compute

MultilabelMatthewsCorrCoef
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelMatthewsCorrCoef
:noindex:
:exclude-members: update, compute


Functional Interface
____________________

matthews_corrcoef
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.matthews_corrcoef
:noindex:

binary_matthews_corrcoef
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_matthews_corrcoef
:noindex:

multiclass_matthews_corrcoef
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_matthews_corrcoef
:noindex:

multilabel_matthews_corrcoef
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_matthews_corrcoef
:noindex:
16 changes: 16 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
ROC,
Accuracy,
AveragePrecision,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryJaccardIndex,
BinaryMatthewsCorrCoef,
BinaryStatScores,
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
Expand All @@ -45,9 +48,14 @@
LabelRankingAveragePrecision,
LabelRankingLoss,
MatthewsCorrCoef,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassJaccardIndex,
MulticlassMatthewsCorrCoef,
MulticlassStatScores,
MultilabelConfusionMatrix,
MultilabelJaccardIndex,
MultilabelMatthewsCorrCoef,
MultilabelStatScores,
Precision,
PrecisionRecallCurve,
Expand Down Expand Up @@ -131,6 +139,8 @@
"CharErrorRate",
"CHRFScore",
"CohenKappa",
"BinaryCohenKappa",
"MulticlassCohenKappa",
"ConfusionMatrix",
"BinaryConfusionMatrix",
"MulticlassConfusionMatrix",
Expand All @@ -147,11 +157,17 @@
"HammingDistance",
"HingeLoss",
"JaccardIndex",
"BinaryJaccardIndex",
"MulticlassJaccardIndex",
"MultilabelJaccardIndex",
"KLDivergence",
"LabelRankingAveragePrecision",
"LabelRankingLoss",
"MatchErrorRate",
"MatthewsCorrCoef",
"BinaryMatthewsCorrCoef",
"MulticlassMatthewsCorrCoef",
"MultilabelMatthewsCorrCoef",
"MaxMetric",
"MeanAbsoluteError",
"MeanAbsolutePercentageError",
Expand Down
41 changes: 26 additions & 15 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.classification.confusion_matrix import ( # noqa: F401 isort:skip
BinaryConfusionMatrix,
ConfusionMatrix,
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from torchmetrics.classification.stat_scores import ( # noqa: F401 isort:skip
BinaryStatScores,
MulticlassStatScores,
MultilabelStatScores,
StatScores,
)

from torchmetrics.classification.accuracy import Accuracy # noqa: F401
from torchmetrics.classification.auc import AUC # noqa: F401
from torchmetrics.classification.auroc import AUROC # noqa: F401
Expand All @@ -19,20 +32,24 @@
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ( # noqa: F401
BinaryConfusionMatrix,
ConfusionMatrix,
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401
from torchmetrics.classification.jaccard import ( # noqa: F401
BinaryJaccardIndex,
JaccardIndex,
MulticlassJaccardIndex,
MultilabelJaccardIndex,
)
from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef # noqa: F401
from torchmetrics.classification.matthews_corrcoef import ( # noqa: F401
BinaryMatthewsCorrCoef,
MatthewsCorrCoef,
MulticlassMatthewsCorrCoef,
MultilabelMatthewsCorrCoef,
)
from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401
from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from torchmetrics.classification.ranking import ( # noqa: F401
Expand All @@ -42,9 +59,3 @@
)
from torchmetrics.classification.roc import ROC # noqa: F401
from torchmetrics.classification.specificity import Specificity # noqa: F401
from torchmetrics.classification.stat_scores import ( # noqa: F401
BinaryStatScores,
MulticlassStatScores,
MultilabelStatScores,
StatScores,
)
Loading

0 comments on commit ed34f7f

Please sign in to comment.