[Refactor] Classification 8/n #1175

Merged · 22 commits · Aug 21, 2022
4 changes: 3 additions & 1 deletion .github/workflows/ci_test-conda.yml
@@ -77,9 +77,11 @@ jobs:
python ./.github/assistant.py prune-packages requirements/detection.txt torchvision
# import of PILLOW_VERSION which they recently removed in v9.0 in favor of __version__
pip install -q "Pillow<9.0" # It messes with torchvision
pip install -e . -r requirements/devel.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install -e . -r requirements/devel.txt "torch==${{ matrix.pytorch-version }}.*" -f $TORCH_URL
pip list
python -c "from torch import __version__ as ver; assert '.'.join(ver.split('.')[:2]) == '${{ matrix.pytorch-version }}', ver"
env:
TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html

- name: DocTests
working-directory: ./src
24 changes: 24 additions & 0 deletions docs/source/classification/calibration_error.rst
@@ -15,8 +15,32 @@ ________________
.. autoclass:: torchmetrics.CalibrationError
:noindex:

BinaryCalibrationError
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryCalibrationError
:noindex:

MulticlassCalibrationError
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassCalibrationError
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.calibration_error
:noindex:

binary_calibration_error
^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_calibration_error
:noindex:

multiclass_calibration_error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_calibration_error
:noindex:
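
As a quick orientation for reviewers, a minimal usage sketch of the new binary functional form, reusing the data from the BinaryCalibrationError docstring further down in this diff (the keyword arguments are an assumption that they mirror the class constructor):

import torch
from torchmetrics.functional import binary_calibration_error

preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])  # probabilities (or logits)
target = torch.tensor([0, 0, 1, 1, 1])                # ground-truth labels in {0, 1}
# n_bins / norm assumed to match the BinaryCalibrationError arguments
binary_calibration_error(preds, target, n_bins=2, norm="l1")  # expected: tensor(0.2900)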
24 changes: 24 additions & 0 deletions docs/source/classification/hinge_loss.rst
@@ -13,8 +13,32 @@ ________________
.. autoclass:: torchmetrics.HingeLoss
:noindex:

BinaryHingeLoss
^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.BinaryHingeLoss
:noindex:

MulticlassHingeLoss
^^^^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.MulticlassHingeLoss
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.hinge_loss
:noindex:

binary_hinge_loss
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.binary_hinge_loss
:noindex:

multiclass_hinge_loss
^^^^^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.multiclass_hinge_loss
:noindex:
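
Analogously for the hinge loss, a rough sketch of calling the new binary functional form; only preds and target are shown, since any further keyword arguments are assumptions not confirmed by this diff:

import torch
from torchmetrics.functional import binary_hinge_loss

preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])  # scores for the positive class
target = torch.tensor([0, 0, 1, 1, 1])                # ground-truth labels in {0, 1}
binary_hinge_loss(preds, target)                      # returns a scalar tensor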
1 change: 1 addition & 0 deletions requirements/classification_test.txt
@@ -0,0 +1 @@
netcal # calibration_error
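
netcal is added to the classification test requirements so the new calibration-error implementations can be checked against an independent reference. A rough sketch of such a comparison; the ECE class and its measure method are taken from netcal's documented API and used here as an assumption about how the tests compare results:

import numpy as np
from netcal.metrics import ECE  # reference implementation used only in tests

preds = np.array([0.25, 0.25, 0.55, 0.75, 0.75])  # probability of the positive class
target = np.array([0, 0, 1, 1, 1])                # ground-truth labels in {0, 1}
reference = ECE(bins=2).measure(preds, target)
# reference should agree with BinaryCalibrationError up to binning conventions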
1 change: 1 addition & 0 deletions requirements/devel.txt
@@ -15,3 +15,4 @@
-r text_test.txt
-r audio_test.txt
-r detection_test.txt
-r classification_test.txt
8 changes: 8 additions & 0 deletions src/torchmetrics/__init__.py
@@ -29,11 +29,13 @@
BinaryAccuracy,
BinaryAUROC,
BinaryAveragePrecision,
BinaryCalibrationError,
BinaryCohenKappa,
BinaryConfusionMatrix,
BinaryF1Score,
BinaryFBetaScore,
BinaryHammingDistance,
BinaryHingeLoss,
BinaryJaccardIndex,
BinaryMatthewsCorrCoef,
BinaryPrecision,
@@ -63,11 +65,13 @@
MulticlassAccuracy,
MulticlassAUROC,
MulticlassAveragePrecision,
MulticlassCalibrationError,
MulticlassCohenKappa,
MulticlassConfusionMatrix,
MulticlassF1Score,
MulticlassFBetaScore,
MulticlassHammingDistance,
MulticlassHingeLoss,
MulticlassJaccardIndex,
MulticlassMatthewsCorrCoef,
MulticlassPrecision,
@@ -300,6 +304,10 @@
"WordErrorRate",
"WordInfoLost",
"WordInfoPreserved",
"BinaryCalibrationError",
"MulticlassHingeLoss",
"BinaryHingeLoss",
"MulticlassCalibrationError",
"MultilabelCoverageError",
"MultilabelRankingAveragePrecision",
"MultilabelRankingLoss",
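With the new names added to the package __all__, the refactored metrics become importable directly from the top-level package, for example:

from torchmetrics import BinaryCalibrationError, BinaryHingeLoss, MulticlassCalibrationError, MulticlassHingeLoss

metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm="l1")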
8 changes: 6 additions & 2 deletions src/torchmetrics/classification/__init__.py
@@ -47,7 +47,11 @@
from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401
from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401
from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401
from torchmetrics.classification.calibration_error import ( # noqa: F401
BinaryCalibrationError,
CalibrationError,
MulticlassCalibrationError,
)
from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401
from torchmetrics.classification.dice import Dice # noqa: F401
from torchmetrics.classification.exact_match import MultilabelExactMatch # noqa: F401
@@ -67,7 +71,7 @@
MulticlassHammingDistance,
MultilabelHammingDistance,
)
from torchmetrics.classification.hinge import HingeLoss # noqa: F401
from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss # noqa: F401
from torchmetrics.classification.jaccard import ( # noqa: F401
BinaryJaccardIndex,
JaccardIndex,
207 changes: 205 additions & 2 deletions src/torchmetrics/classification/calibration_error.py
@@ -11,16 +11,219 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List
from typing import Any, List, Optional

import torch
from torch import Tensor
from typing_extensions import Literal

from torchmetrics.functional.classification.calibration_error import _ce_compute, _ce_update
from torchmetrics.functional.classification.calibration_error import (
_binary_calibration_error_arg_validation,
_binary_calibration_error_tensor_validation,
_binary_calibration_error_update,
_binary_confusion_matrix_format,
_ce_compute,
_ce_update,
_multiclass_calibration_error_arg_validation,
_multiclass_calibration_error_tensor_validation,
_multiclass_calibration_error_update,
_multiclass_confusion_matrix_format,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat


class BinaryCalibrationError(Metric):
r"""`Computes the Top-label Calibration Error`_ for binary tasks. The expected calibration error can be used to
quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches
the actual probabilities of the ground truth distribution.

Three different norms are implemented, each corresponding to variations on the calibration error metric.

.. math::
\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}

.. math::
\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}

.. math::
\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}

Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of
predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed
in a uniform way in the [0,1] range.

Accepts the following input tensors:

- ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each
observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply
sigmoid per element.
- ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
only contain {0,1} values (except if `ignore_index` is specified).

Additional dimension ``...`` will be flattened into the batch dimension.

Args:
n_bins: Number of bins to use when computing the metric.
norm: Norm used to compare empirical and expected probability bins.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> from torchmetrics import BinaryCalibrationError
>>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
>>> target = torch.tensor([0, 0, 1, 1, 1])
>>> metric = BinaryCalibrationError(n_bins=2, norm='l1')
>>> metric(preds, target)
tensor(0.2900)
>>> metric = BinaryCalibrationError(n_bins=2, norm='l2')
>>> metric(preds, target)
tensor(0.2918)
>>> metric = BinaryCalibrationError(n_bins=2, norm='max')
>>> metric(preds, target)
tensor(0.3167)
"""
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False

def __init__(
self,
n_bins: int = 15,
norm: Literal["l1", "l2", "max"] = "l1",
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if validate_args:
_binary_calibration_error_arg_validation(n_bins, norm, ignore_index)
self.validate_args = validate_args
self.n_bins = n_bins
self.norm = norm
self.ignore_index = ignore_index
self.add_state("confidences", [], dist_reduce_fx="cat")
self.add_state("accuracies", [], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
if self.validate_args:
_binary_calibration_error_tensor_validation(preds, target, self.ignore_index)
preds, target = _binary_confusion_matrix_format(
preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False
)
confidences, accuracies = _binary_calibration_error_update(preds, target)
self.confidences.append(confidences)
self.accuracies.append(accuracies)

def compute(self) -> Tensor:
confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)
return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)


class MulticlassCalibrationError(Metric):
r"""`Computes the Top-label Calibration Error`_ for multiclass tasks. The expected calibration error can be used to
quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches
the actual probabilities of the ground truth distribution.

Three different norms are implemented, each corresponding to variations on the calibration error metric.

.. math::
\text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)}

.. math::
\text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)}

.. math::
\text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)}

Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of
predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed
in a uniform way in the [0,1] range.

Accepts the following input tensors:

- ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each
observation. If preds has values outside the [0,1] range, we consider the input to be logits and will auto apply
softmax per sample.
- ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore
only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified).

Additional dimension ``...`` will be flattened into the batch dimension.

Args:
num_classes: Integer specifying the number of classes
n_bins: Number of bins to use when computing the metric.
norm: Norm used to compare empirical and expected probability bins.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
>>> from torchmetrics import MulticlassCalibrationError
>>> preds = torch.tensor([[0.25, 0.20, 0.55],
... [0.55, 0.05, 0.40],
... [0.10, 0.30, 0.60],
... [0.90, 0.05, 0.05]])
>>> target = torch.tensor([0, 1, 2, 0])
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1')
>>> metric(preds, target)
tensor(0.2000)
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l2')
>>> metric(preds, target)
tensor(0.2082)
>>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='max')
>>> metric(preds, target)
tensor(0.2333)
"""
is_differentiable: bool = False
higher_is_better: bool = False
full_state_update: bool = False

def __init__(
self,
num_classes: int,
n_bins: int = 15,
norm: Literal["l1", "l2", "max"] = "l1",
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if validate_args:
_multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index)
self.validate_args = validate_args
self.num_classes = num_classes
self.n_bins = n_bins
self.norm = norm
self.ignore_index = ignore_index
self.add_state("confidences", [], dist_reduce_fx="cat")
self.add_state("accuracies", [], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore
if self.validate_args:
_multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target = _multiclass_confusion_matrix_format(
preds, target, ignore_index=self.ignore_index, convert_to_labels=False
)
confidences, accuracies = _multiclass_calibration_error_update(preds, target)
self.confidences.append(confidences)
self.accuracies.append(accuracies)

def compute(self) -> Tensor:
confidences = dim_zero_cat(self.confidences)
accuracies = dim_zero_cat(self.accuracies)
return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm)


# -------------------------- Old stuff --------------------------


class CalibrationError(Metric):
r"""`Computes the Top-label Calibration Error`_
Three different norms are implemented, each corresponding to variations on the calibration error metric.
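
To make the binning in the formulas above concrete, here is a small hand-rolled sketch that reproduces the L1/ECE value from the BinaryCalibrationError docstring example. The way per-sample confidences and accuracies are derived is an assumption about what the internal _binary_calibration_error_update step does; only the final number is taken from the docstring:

import torch

preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
target = torch.tensor([0, 0, 1, 1, 1])

# top-label confidence and whether the top-label prediction was correct
confidences = torch.where(preds > 0.5, preds, 1 - preds)
accuracies = torch.where(preds > 0.5, target, 1 - target).float()

n_bins = 2
edges = torch.linspace(0, 1, n_bins + 1)
ece = torch.tensor(0.0)
for lo, hi in zip(edges[:-1], edges[1:]):
    in_bin = (confidences > lo) & (confidences <= hi)
    if in_bin.any():
        b_i = in_bin.float().mean()       # fraction of samples in the bin
        p_i = accuracies[in_bin].mean()   # accuracy within the bin
        c_i = confidences[in_bin].mean()  # average confidence within the bin
        ece += b_i * (p_i - c_i).abs()

print(ece)  # ~0.2900, matching BinaryCalibrationError(n_bins=2, norm='l1')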