[Refactor] Classification 9/n (#1189)
* micro to macro

* new kl

* deprecate old

* move auc

* move tests around

* change import paths

* update test path

* change ref

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* update doctests

* doctests fix

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
SkafteNicki and justusschock authored Aug 26, 2022
1 parent 031ffdd commit 033e458
Showing 31 changed files with 446 additions and 324 deletions.
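In user-facing terms, the bullets above amount to three changes: `KLDivergence` moves from the classification to the regression package, the `AUC` class is deprecated in favour of the functional helpers in `torchmetrics.utilities.compute`, and the refactored classification metrics switch their default `average` from `"micro"` to `"macro"`. A minimal migration sketch (assuming torchmetrics 0.10, where the old import path still works but warns):

```python
# Old location -- still importable in v0.10, scheduled for removal in v0.11,
# and emits a DeprecationWarning when instantiated:
# from torchmetrics.classification import KLDivergence

# New canonical location after this commit:
from torchmetrics.regression import KLDivergence

kl = KLDivergence(log_prob=False, reduction="mean")
```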
2 changes: 1 addition & 1 deletion docs/source/pages/overview.rst
@@ -183,7 +183,7 @@ the following limitations:

- :ref:`image/peak_signal_noise_ratio:Peak Signal-to-Noise Ratio (PSNR)`
- :ref:`image/structural_similarity:Structural Similarity Index Measure (SSIM)`
- :ref:`classification/kl_divergence:KL Divergence`
- :ref:`regression/kl_divergence:KL Divergence`

You can always check the precision/dtype of the metric by checking the `.dtype` property.

File renamed without changes.
2 changes: 1 addition & 1 deletion src/torchmetrics/__init__.py
@@ -58,7 +58,6 @@
HammingDistance,
HingeLoss,
JaccardIndex,
KLDivergence,
LabelRankingAveragePrecision,
LabelRankingLoss,
MatthewsCorrCoef,
@@ -121,6 +120,7 @@
from torchmetrics.regression import ( # noqa: E402
CosineSimilarity,
ExplainedVariance,
KLDivergence,
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredError,
1 change: 0 additions & 1 deletion src/torchmetrics/classification/__init__.py
@@ -78,7 +78,6 @@
MulticlassJaccardIndex,
MultilabelJaccardIndex,
)
from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401
from torchmetrics.classification.matthews_corrcoef import ( # noqa: F401
BinaryMatthewsCorrCoef,
MatthewsCorrCoef,
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/accuracy.py
@@ -173,7 +173,7 @@ class MulticlassAccuracy(MulticlassStatScores):
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.8333)
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5000, 1.0000, 1.0000])
@@ -189,7 +189,7 @@ class MulticlassAccuracy(MulticlassStatScores):
... ])
>>> metric = MulticlassAccuracy(num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.8333)
>>> metric = MulticlassAccuracy(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5000, 1.0000, 1.0000])
@@ -200,7 +200,7 @@ class MulticlassAccuracy(MulticlassStatScores):
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.3333])
tensor([0.5000, 0.2778])
>>> metric = MulticlassAccuracy(num_classes=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[1.0000, 0.0000, 0.5000],
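The doctest changes above are driven entirely by the new default `average="macro"`; the per-class values are unchanged. A short sketch of the arithmetic for the first example, using only the class-wise scores shown in the `average=None` doctest (the underlying `target` tensor sits above the hunk and is not reproduced here):

```python
import torch

# Class-wise accuracies from MulticlassAccuracy(num_classes=3, average=None)
per_class = torch.tensor([0.5000, 1.0000, 1.0000])

# New default (macro): unweighted mean over classes
macro = per_class.mean()   # tensor(0.8333) -> the updated doctest value

# Old default (micro): pool all samples before dividing; 3 of the 4
# predictions are correct, so the previous doctest value was 3 / 4 = 0.75
micro = torch.tensor(3 / 4)
```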
12 changes: 10 additions & 2 deletions src/torchmetrics/classification/auc.py
@@ -15,9 +15,9 @@

from torch import Tensor

from torchmetrics.functional.classification.auc import _auc_compute, _auc_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.compute import _auc_compute, _auc_format_inputs
from torchmetrics.utilities.data import dim_zero_cat


@@ -28,6 +28,9 @@ class AUC(Metric):
Forward accepts two input tensors that should be 1D and have the same number
of elements
.. note::
This metric has been deprecated in v0.10 and will be removed in v0.11.
Args:
reorder: AUC expects its first input to be sorted. If this is not the case,
setting this argument to ``True`` will use a stable sorting algorithm to
@@ -47,6 +50,11 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"`torchmetrics.classification.AUC` has been deprecated in v0.10 and will be removed in v0.11."
"A functional version is still available in `torchmetrics.utilities.compute`",
DeprecationWarning,
)

self.reorder = reorder

@@ -65,7 +73,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
preds: Predictions from model (probabilities, or labels)
target: Ground truth labels
"""
x, y = _auc_update(preds, target)
x, y = _auc_format_inputs(preds, target)

self.x.append(x)
self.y.append(y)
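For callers that only need a raw area-under-curve once `AUC` is removed, the warning above points at `torchmetrics.utilities.compute`. A rough sketch of that route, assuming the private helpers `_auc_format_inputs` and `_auc_compute` keep the signatures implied by the `update`/`compute` methods in this file (illustrative only, since private helpers carry no stability guarantee):

```python
import torch
from torchmetrics.utilities.compute import _auc_compute, _auc_format_inputs

x = torch.tensor([0.00, 0.25, 0.50, 1.00])
y = torch.tensor([0.00, 0.50, 0.70, 1.00])

# Validate/flatten the inputs, then integrate y over x (trapezoidal rule);
# reorder=True sorts by x first, mirroring the class's `reorder` argument.
x, y = _auc_format_inputs(x, y)
area = _auc_compute(x, y, reorder=True)
```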
32 changes: 16 additions & 16 deletions src/torchmetrics/classification/f_beta.py
@@ -192,7 +192,7 @@ class MulticlassFBetaScore(MulticlassStatScores):
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.7963)
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5556, 0.8333, 1.0000])
@@ -208,7 +208,7 @@ class MulticlassFBetaScore(MulticlassStatScores):
... ])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.7963)
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5556, 0.8333, 1.0000])
@@ -219,7 +219,7 @@ class MulticlassFBetaScore(MulticlassStatScores):
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.3333])
tensor([0.4697, 0.2706])
>>> metric = MulticlassFBetaScore(beta=2.0, num_classes=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[0.9091, 0.0000, 0.5000],
@@ -234,7 +234,7 @@ def __init__(
beta: float,
num_classes: int,
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
@@ -320,7 +320,7 @@ class MultilabelFBetaScore(MultilabelStatScores):
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
tensor(0.6111)
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> metric(preds, target)
tensor([1.0000, 0.0000, 0.8333])
@@ -331,7 +331,7 @@ class MultilabelFBetaScore(MultilabelStatScores):
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
tensor(0.6111)
>>> metric = MultilabelFBetaScore(beta=2.0, num_labels=3, average=None)
>>> metric(preds, target)
tensor([1.0000, 0.0000, 0.8333])
@@ -347,7 +347,7 @@ class MultilabelFBetaScore(MultilabelStatScores):
... )
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5882, 0.0000])
tensor([0.5556, 0.0000])
>>> metric = MultilabelFBetaScore(num_labels=3, beta=2.0, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[0.8333, 0.8333, 0.0000],
@@ -363,7 +363,7 @@ def __init__(
beta: float,
num_labels: int,
threshold: float = 0.5,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
@@ -536,7 +536,7 @@ class MulticlassF1Score(MulticlassFBetaScore):
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.7778)
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.6667, 0.6667, 1.0000])
@@ -552,7 +552,7 @@ class MulticlassF1Score(MulticlassFBetaScore):
... ])
>>> metric = MulticlassF1Score(num_classes=3)
>>> metric(preds, target)
tensor(0.7500)
tensor(0.7778)
>>> metric = MulticlassF1Score(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.6667, 0.6667, 1.0000])
@@ -563,7 +563,7 @@ class MulticlassF1Score(MulticlassFBetaScore):
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.3333])
tensor([0.4333, 0.2667])
>>> metric = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[0.8000, 0.0000, 0.5000],
@@ -577,7 +577,7 @@ def __init__(
self,
num_classes: int,
top_k: int = 1,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
@@ -654,7 +654,7 @@ class MultilabelF1Score(MultilabelFBetaScore):
>>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
tensor(0.5556)
>>> metric = MultilabelF1Score(num_labels=3, average=None)
>>> metric(preds, target)
tensor([1.0000, 0.0000, 0.6667])
@@ -665,7 +665,7 @@ class MultilabelF1Score(MultilabelFBetaScore):
>>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
>>> metric = MultilabelF1Score(num_labels=3)
>>> metric(preds, target)
tensor(0.6667)
tensor(0.5556)
>>> metric = MultilabelF1Score(num_labels=3, average=None)
>>> metric(preds, target)
tensor([1.0000, 0.0000, 0.6667])
@@ -681,7 +681,7 @@ class MultilabelF1Score(MultilabelFBetaScore):
... )
>>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.0000])
tensor([0.4444, 0.0000])
>>> metric = MultilabelF1Score(num_labels=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[0.6667, 0.6667, 0.0000],
@@ -696,7 +696,7 @@ def __init__(
self,
num_labels: int,
threshold: float = 0.5,
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
multidim_average: Literal["global", "samplewise"] = "global",
ignore_index: Optional[int] = None,
validate_args: bool = True,
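Every changed doctest value in this file follows the same micro-to-macro switch; for example, the per-class F1 scores `[0.6667, 0.6667, 1.0000]` average to the new `0.7778`, whereas pooling all samples gave the old `0.7500`. Code that depends on the previous behaviour can pin it explicitly, as in this short sketch (constructor arguments as defined above):

```python
from torchmetrics.classification import MulticlassF1Score, MultilabelFBetaScore

# Opt back into the pre-0.10 default instead of the new average="macro"
f1 = MulticlassF1Score(num_classes=3, average="micro")
fbeta = MultilabelFBetaScore(beta=2.0, num_labels=3, average="micro")
```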
6 changes: 3 additions & 3 deletions src/torchmetrics/classification/hamming.py
@@ -165,7 +165,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
>>> preds = torch.tensor([2, 1, 0, 1])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.2500)
tensor(0.1667)
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5000, 0.0000, 0.0000])
@@ -181,7 +181,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
... ])
>>> metric = MulticlassHammingDistance(num_classes=3)
>>> metric(preds, target)
tensor(0.2500)
tensor(0.1667)
>>> metric = MulticlassHammingDistance(num_classes=3, average=None)
>>> metric(preds, target)
tensor([0.5000, 0.0000, 0.0000])
@@ -192,7 +192,7 @@ class MulticlassHammingDistance(MulticlassStatScores):
>>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
>>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise')
>>> metric(preds, target)
tensor([0.5000, 0.6667])
tensor([0.5000, 0.7222])
>>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise', average=None)
>>> metric(preds, target)
tensor([[0.0000, 1.0000, 0.5000],
48 changes: 11 additions & 37 deletions src/torchmetrics/classification/kl_divergence.py
@@ -13,16 +13,13 @@
# limitations under the License.
from typing import Any

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.kl_divergence import _kld_compute, _kld_update
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.regression.kl_divergence import KLDivergence as _KLDivergence
from torchmetrics.utilities.prints import rank_zero_warn


class KLDivergence(Metric):
class KLDivergence(_KLDivergence):
r"""Computes the `KL divergence`_:
.. math::
@@ -46,6 +43,8 @@ class KLDivergence(Metric):
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
.. note::
This metric has been moved to the regression package in v0.10 and this version will be removed in v0.11.
Raises:
TypeError:
@@ -65,41 +64,16 @@ class KLDivergence(Metric):
tensor(0.0853)
"""
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
total: Tensor

def __init__(
self,
log_prob: bool = False,
reduction: Literal["mean", "sum", "none", None] = "mean",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if not isinstance(log_prob, bool):
raise TypeError(f"Expected argument `log_prob` to be bool but got {log_prob}")
self.log_prob = log_prob

allowed_reduction = ["mean", "sum", "none", None]
if reduction not in allowed_reduction:
raise ValueError(f"Expected argument `reduction` to be one of {allowed_reduction} but got {reduction}")
self.reduction = reduction

if self.reduction in ["mean", "sum"]:
self.add_state("measures", torch.tensor(0.0), dist_reduce_fx="sum")
else:
self.add_state("measures", [], dist_reduce_fx="cat")
self.add_state("total", torch.tensor(0), dist_reduce_fx="sum")

def update(self, p: Tensor, q: Tensor) -> None: # type: ignore
measures, total = _kld_update(p, q, self.log_prob)
if self.reduction is None or self.reduction == "none":
self.measures.append(measures)
else:
self.measures += measures.sum()
self.total += total

def compute(self) -> Tensor:
measures = dim_zero_cat(self.measures) if self.reduction is None or self.reduction == "none" else self.measures
return _kld_compute(measures, self.total, self.reduction)
super().__init__(log_prob, reduction, **kwargs)
rank_zero_warn(
"`torchmetrics.classification.KLDivergence` has been moved to `torchmetrics.regression.KLDivergence`"
" from v0.10 and this version will be removed in v0.11. Please update import paths.",
DeprecationWarning,
)
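The classification-side class is now a thin shim over the regression implementation: it forwards `log_prob` and `reduction` to the parent and emits a `DeprecationWarning`. A quick sanity check of the doctest value kept above, using the standard example inputs from the torchmetrics docs (assumed here, since the doctest body is collapsed in this diff):

```python
import torch
from torchmetrics.regression import KLDivergence

p = torch.tensor([[0.36, 0.48, 0.16]])  # predicted distribution
q = torch.tensor([[1/3, 1/3, 1/3]])     # target distribution

kl = KLDivergence()
print(kl(p, q))  # tensor(0.0853), matching the docstring output above
```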