Skip to content

Commit

Permalink
[Refactor] Classification 1/n (#1054)
Browse files Browse the repository at this point in the history
* base structure

* bincount

* binary

* files

* stat score

* multiclass + multilabel confmat

* update

* stat_score

* change bincount

* move back

* del tests

* rest of structure

* confmat working

* working binary stat scores

* full testing

* update

* update

* add missing tests

* update

* multilabel stat scores

* disable old testing

* more testing

* flaky tests

* changelog

* refactor

* fixes

* typing

* update

* fix tests

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* Update src/torchmetrics/functional/classification/confusion_matrix.py

* missing literal

* add docstring to functional confusion matrix

* add docstring to modular confusion matrix

* add docstring to functional stat scores

* add docstring to modular stat scores

* make private

* docs

* fix mypy and doctests

* fix docs formatting

* literal backwards

* custom movedim

* debug

* debug

* fix tests

* fix tests

* fix tests

* fix tests

* add some testing

* fix tests

* fix docstring

* fix tests

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
5 people committed Sep 13, 2022
1 parent 91cd625 commit b7120e3
Show file tree
Hide file tree
Showing 21 changed files with 3,561 additions and 439 deletions.
2 changes: 1 addition & 1 deletion .github/actions/unittesting/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ runs:

- name: Unittests
working-directory: ./tests
run: python -m pytest ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }}
run: python -m pytest -v --maxfail=5 ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }}
shell: ${{ inputs.shell-type }}

- name: Upload pytest results
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Classification refactor (
[#1054](https://github.com/Lightning-AI/metrics/pull/1054),
)

- Changed update in `FID` metric to be done in a online fashion to save memory ([#1199](https://github.com/PyTorchLightning/metrics/pull/1199))


Expand Down
45 changes: 45 additions & 0 deletions docs/source/classification/confusion_matrix.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,56 @@ Confusion Matrix
Module Interface
________________

ConfusionMatrix
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.ConfusionMatrix
:noindex:

BinaryConfusionMatrix
^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryConfusionMatrix
:noindex:
:exclude-members: update, compute

MulticlassConfusionMatrix
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassConfusionMatrix
:noindex:
:exclude-members: update, compute

MultilabelConfusionMatrix
^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelConfusionMatrix
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

confusion_matrix
^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.confusion_matrix
:noindex:

binary_confusion_matrix
^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_confusion_matrix
:noindex:

multiclass_confusion_matrix
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_confusion_matrix
:noindex:

multilabel_confusion_matrix
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_confusion_matrix
:noindex:
45 changes: 45 additions & 0 deletions docs/source/classification/stat_scores.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,56 @@ Stat Scores
Module Interface
________________

StatScores
^^^^^^^^^^

.. autoclass:: torchmetrics.StatScores
:noindex:

BinaryStatScores
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryStatScores
:noindex:
:exclude-members: update, compute

MulticlassStatScores
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassStatScores
:noindex:
:exclude-members: update, compute

MultilabelStatScores
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MultilabelStatScores
:noindex:
:exclude-members: update, compute

Functional Interface
____________________

stat_scores
^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.stat_scores
:noindex:

binary_stat_scores
^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_stat_scores
:noindex:

multiclass_stat_scores
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_stat_scores
:noindex:

multilabel_stat_scores
^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multilabel_stat_scores
:noindex:
12 changes: 12 additions & 0 deletions src/torchmetrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
ROC,
Accuracy,
AveragePrecision,
BinaryConfusionMatrix,
BinaryStatScores,
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
Expand All @@ -43,6 +45,10 @@
LabelRankingAveragePrecision,
LabelRankingLoss,
MatthewsCorrCoef,
MulticlassConfusionMatrix,
MulticlassStatScores,
MultilabelConfusionMatrix,
MultilabelStatScores,
Precision,
PrecisionRecallCurve,
Recall,
Expand Down Expand Up @@ -126,6 +132,9 @@
"CHRFScore",
"CohenKappa",
"ConfusionMatrix",
"BinaryConfusionMatrix",
"MulticlassConfusionMatrix",
"MultilabelConfusionMatrix",
"CosineSimilarity",
"CoverageError",
"Dice",
Expand Down Expand Up @@ -187,6 +196,9 @@
"SQuAD",
"StructuralSimilarityIndexMeasure",
"StatScores",
"BinaryStatScores",
"MulticlassStatScores",
"MultilabelStatScores",
"SumMetric",
"SymmetricMeanAbsolutePercentageError",
"TranslationEditRate",
Expand Down
14 changes: 12 additions & 2 deletions src/torchmetrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401
from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401
from torchmetrics.classification.confusion_matrix import ( # noqa: F401
BinaryConfusionMatrix,
ConfusionMatrix,
MulticlassConfusionMatrix,
MultilabelConfusionMatrix,
)
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401
from torchmetrics.classification.hamming import HammingDistance # noqa: F401
Expand All @@ -37,4 +42,9 @@
)
from torchmetrics.classification.roc import ROC # noqa: F401
from torchmetrics.classification.specificity import Specificity # noqa: F401
from torchmetrics.classification.stat_scores import StatScores # noqa: F401
from torchmetrics.classification.stat_scores import ( # noqa: F401
BinaryStatScores,
MulticlassStatScores,
MultilabelStatScores,
StatScores,
)
Loading

0 comments on commit b7120e3

Please sign in to comment.