Code cleaning after classification refactor 1/n (#1251)
* cleanup
* remove some more
* fix broken doctest
* remove tests

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
SkafteNicki and pre-commit-ci[bot] authored Oct 12, 2022
1 parent eacfd2f commit 84ef54a
Showing 37 changed files with 67 additions and 1,524 deletions.
18 changes: 16 additions & 2 deletions CHANGELOG.md
@@ -29,8 +29,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-

- Removed deprecated `BinnedAveragePrecision`, `BinnedPrecisionRecallCurve`, `RecallAtFixedPrecision` ([#1251](https://github.com/Lightning-AI/metrics/pull/1251))
- Removed deprecated `LabelRankingAveragePrecision`, `LabelRankingLoss` and `CoverageError` ([#1251](https://github.com/Lightning-AI/metrics/pull/1251))
- Removed deprecated `KLDivergence` and `AUC` ([#1251](https://github.com/Lightning-AI/metrics/pull/1251))

### Fixed

@@ -74,6 +75,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improved performance of retrieval metrics ([#1242](https://github.com/Lightning-AI/metrics/pull/1242))
- Changed `SSIM` and `MSSSIM` update to be online to reduce memory usage ([#1231](https://github.com/Lightning-AI/metrics/pull/1231))

### Deprecated

- Deprecated `BinnedAveragePrecision`, `BinnedPrecisionRecallCurve`, `BinnedRecallAtFixedPrecision` ([#1163](https://github.com/Lightning-AI/metrics/pull/1163))
* `BinnedAveragePrecision` -> use `AveragePrecision` with `thresholds` arg
* `BinnedPrecisionRecallCurve` -> use `PrecisionRecallCurve` with `thresholds` arg
* `BinnedRecallAtFixedPrecision` -> use `RecallAtFixedPrecision` with `thresholds` arg
- Renamed and refactored `LabelRankingAveragePrecision`, `LabelRankingLoss` and `CoverageError` ([#1167](https://github.com/Lightning-AI/metrics/pull/1167))
* `LabelRankingAveragePrecision` -> `MultilabelRankingAveragePrecision`
* `LabelRankingLoss` -> `MultilabelRankingLoss`
* `CoverageError` -> `MultilabelCoverageError`
- Deprecated `KLDivergence` and `AUC` from classification package ([#1189](https://github.com/Lightning-AI/metrics/pull/1189))
* `KLDivergence` moved to `regression` package
* Instead of `AUC` use `torchmetrics.utilities.compute.auc`
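The migration paths above can be sketched in plain Python. This is an illustrative approximation of the two ideas involved: evaluating a precision-recall curve only at a fixed set of thresholds (the memory-saving trick behind the `thresholds` argument), and computing AUC as a stand-alone trapezoidal-rule utility. It is not the torchmetrics implementation.

```python
def binned_precision_recall(scores, labels, thresholds):
    """Precision/recall evaluated only at fixed thresholds, the idea
    behind the `thresholds` argument of the replacement metrics."""
    precisions, recalls = [], []
    positives = sum(labels)
    for t in thresholds:
        # Predictions with score >= t are counted as positive.
        tp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 1)
        fp = sum(1 for s, y in zip(scores, labels) if s >= t and y == 0)
        precisions.append(tp / (tp + fp) if tp + fp else 1.0)
        recalls.append(tp / positives if positives else 0.0)
    return precisions, recalls


def auc(x, y):
    """Trapezoidal area under a curve, the kind of computation the
    stand-alone auc utility performs on sorted x values."""
    return sum((x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2.0 for i in range(1, len(x)))


scores = [0.1, 0.4, 0.6, 0.9]
labels = [0, 0, 1, 1]
precision, recall = binned_precision_recall(scores, labels, [0.0, 0.5, 1.0])
print(precision)  # [0.5, 1.0, 1.0]
print(recall)     # [1.0, 1.0, 0.0]
print(auc([0.0, 0.5, 1.0], [0.0, 1.0, 1.0]))  # 0.75
```

Binning trades exactness for bounded memory: only the counts per threshold are kept, rather than every prediction.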

### Fixed

20 changes: 0 additions & 20 deletions docs/source/classification/auc.rst

This file was deleted.

14 changes: 0 additions & 14 deletions docs/source/classification/binned_average_precision.rst

This file was deleted.

14 changes: 0 additions & 14 deletions docs/source/classification/binned_precision_recall_curve.rst

This file was deleted.

14 changes: 0 additions & 14 deletions docs/source/classification/binned_recall_fixed_precision.rst

This file was deleted.

6 changes: 0 additions & 6 deletions docs/source/classification/coverage_error.rst
@@ -10,17 +10,11 @@ Coverage Error
Module Interface
________________

.. autoclass:: torchmetrics.CoverageError
:noindex:

.. autoclass:: torchmetrics.classification.MultilabelCoverageError
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.coverage_error
:noindex:

.. autofunction:: torchmetrics.functional.classification.multilabel_coverage_error
:noindex:
@@ -10,18 +10,12 @@ Label Ranking Average Precision
Module Interface
________________

.. autoclass:: torchmetrics.LabelRankingAveragePrecision
:noindex:

.. autoclass:: torchmetrics.classification.MultilabelRankingAveragePrecision
:noindex:


Functional Interface
____________________

.. autofunction:: torchmetrics.functional.label_ranking_average_precision
:noindex:

.. autofunction:: torchmetrics.functional.classification.multilabel_ranking_average_precision
:noindex:
7 changes: 0 additions & 7 deletions docs/source/classification/label_ranking_loss.rst
@@ -10,18 +10,11 @@ Label Ranking Loss
Module Interface
________________

.. autoclass:: torchmetrics.LabelRankingLoss
:noindex:


.. autoclass:: torchmetrics.classification.MultilabelRankingLoss
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.label_ranking_loss
:noindex:

.. autofunction:: torchmetrics.functional.classification.multilabel_ranking_loss
:noindex:
1 change: 0 additions & 1 deletion docs/source/index.rst
@@ -132,7 +132,6 @@ Or directly from conda
pages/overview
pages/implement
pages/lightning
pages/classification
pages/retrieval

.. toctree::
101 changes: 0 additions & 101 deletions docs/source/pages/classification.rst

This file was deleted.

14 changes: 0 additions & 14 deletions src/torchmetrics/__init__.py
@@ -21,26 +21,19 @@
SignalNoiseRatio,
)
from torchmetrics.classification import ( # noqa: E402
AUC,
AUROC,
ROC,
Accuracy,
AveragePrecision,
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
CalibrationError,
CohenKappa,
ConfusionMatrix,
CoverageError,
Dice,
F1Score,
FBetaScore,
HammingDistance,
HingeLoss,
JaccardIndex,
LabelRankingAveragePrecision,
LabelRankingLoss,
MatthewsCorrCoef,
Precision,
PrecisionRecallCurve,
@@ -113,12 +106,8 @@
__all__ = [
"functional",
"Accuracy",
"AUC",
"AUROC",
"AveragePrecision",
"BinnedAveragePrecision",
"BinnedPrecisionRecallCurve",
"BinnedRecallAtFixedPrecision",
"BLEUScore",
"BootStrapper",
"CalibrationError",
@@ -130,7 +119,6 @@
"CohenKappa",
"ConfusionMatrix",
"CosineSimilarity",
"CoverageError",
"Dice",
"TweedieDevianceScore",
"ErrorRelativeGlobalDimensionlessSynthesis",
@@ -142,8 +130,6 @@
"HingeLoss",
"JaccardIndex",
"KLDivergence",
"LabelRankingAveragePrecision",
"LabelRankingLoss",
"MatchErrorRate",
"MatthewsCorrCoef",
"MaxMetric",
10 changes: 0 additions & 10 deletions src/torchmetrics/classification/__init__.py
@@ -29,21 +29,14 @@
MultilabelStatScores,
StatScores,
)

from torchmetrics.classification.accuracy import Accuracy, BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy
from torchmetrics.classification.auc import AUC
from torchmetrics.classification.auroc import AUROC, BinaryAUROC, MulticlassAUROC, MultilabelAUROC
from torchmetrics.classification.average_precision import (
AveragePrecision,
BinaryAveragePrecision,
MulticlassAveragePrecision,
MultilabelAveragePrecision,
)
from torchmetrics.classification.binned_precision_recall import (
BinnedAveragePrecision,
BinnedPrecisionRecallCurve,
BinnedRecallAtFixedPrecision,
)
from torchmetrics.classification.calibration_error import (
BinaryCalibrationError,
CalibrationError,
@@ -92,9 +85,6 @@
Recall,
)
from torchmetrics.classification.ranking import (
CoverageError,
LabelRankingAveragePrecision,
LabelRankingLoss,
MultilabelCoverageError,
MultilabelRankingAveragePrecision,
MultilabelRankingLoss,
13 changes: 3 additions & 10 deletions src/torchmetrics/classification/accuracy.py
@@ -350,8 +350,6 @@ class Accuracy(StatScores):
changed to subset accuracy (which requires all labels or sub-samples in the sample to
be correctly predicted) by setting ``subset_accuracy=True``.
Accepts all input types listed in :ref:`pages/classification:input types`.
Args:
num_classes:
Number of classes. Necessary for ``'macro'``, ``'weighted'`` and ``None`` average methods.
@@ -387,11 +385,10 @@ class Accuracy(StatScores):
- ``'samplewise'``: In this case, the statistics are computed separately for each
sample on the ``N`` axis, and then averaged over samples.
The computation for each sample is done by treating the flattened extra axes ``...``
(see :ref:`pages/classification:input types`) as the ``N`` dimension within the sample,
as the ``N`` dimension within the sample,
and computing the metric for the sample based on that.
- ``'global'``: In this case the ``N`` and ``...`` dimensions of the inputs
(see :ref:`pages/classification:input types`)
are flattened into a new ``N_X`` sample axis, i.e. the inputs are treated as if they
were ``(N_X, C)``. From here on the ``average`` parameter applies as usual.
@@ -409,9 +406,7 @@ class Accuracy(StatScores):
multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <pages/classification:using the multiclass parameter>`
for a more detailed explanation and examples.
than what they appear to be.
subset_accuracy:
Whether to compute subset accuracy for multi-label and multi-dimensional
@@ -557,9 +552,7 @@ def __init__(
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
"""Update state with predictions and targets. See
:ref:`pages/classification:input types` for more information on input
types.
"""Update state with predictions and targets.
Args:
preds: Predictions from model (logits, probabilities, or labels)
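The `subset_accuracy` behavior discussed in the docstring above can be illustrated with a small pure-Python sketch (a hypothetical helper, not the torchmetrics implementation): under subset accuracy, a multi-label sample only counts as correct when every one of its labels is predicted correctly.

```python
def subset_accuracy(preds, target):
    """Fraction of samples whose full label vector matches exactly.
    preds/target: equal-length lists of 0/1 label lists, one per sample."""
    # A sample is correct only if its entire label list matches.
    correct = sum(1 for p, t in zip(preds, target) if p == t)
    return correct / len(target)


preds  = [[1, 0, 1], [0, 1, 1], [1, 1, 0]]
target = [[1, 0, 1], [0, 1, 0], [1, 1, 0]]
print(subset_accuracy(preds, target))  # 2 of 3 samples match exactly
```

Per-label accuracy on the same data would be 8/9, which is why `subset_accuracy=True` is the stricter criterion.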