From 57ccb9696e959e4768c8ab40dc9cbc43bf01f851 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 28 Jul 2022 20:31:23 +0200 Subject: [PATCH 01/17] changes --- .../functional/classification/hinge.py | 112 +++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 4d57dcc58f1..5c117d2d5ec 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,16 +11,126 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from multiprocessing.sharedctypes import Value from typing import Optional, Tuple, Union +from attr import validate import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr +from torchmetrics.utilities.checks import _check_same_shape +def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: + return measure / total + + +def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None: + if not isinstance(squared, bool): + raise ValueError(f"Expected argument `squared` to be an bool but got {squared}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_hinge_loss_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + # Check that they have same shape + _check_same_shape(preds, target) + + # Check that target only contains {0,1} values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." 
+ ) + + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be an floating tensor with probability/logit scores," + f" but got tensor with dtype {preds.dtype}" + ) + + +def _binary_hinge_loss_update( + preds: Tensor, + target: Tensor, + squared: bool, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + + target = target.bool() + margin = torch.zeros_like(preds) + margin[target] = preds[target] + margin[~target] = -preds[~target] + + + measures = 1 - margin + measures = torch.clamp(measures, 0) + + if squared: + measures = measures.pow(2) + + total = tensor(target.shape[0], device=target.device) + return measures.sum(dim=0), total + +def binary_hinge_loss( + preds: Tensor, + target: Tensor, + squared: bool = False, + ignore_index: Optional[int] = None, + validate_args: bool = False +) -> Tensor: + if validate_args: + _binary_hinge_loss_arg_validation(squared, ignore_index) + _binary_hinge_loss_tensor_validation(preds, target, ignore_index) + measures, total = _binary_hinge_loss_update(preds, target, squared, ignore_index) + return _hinge_loss_compute(measures, total) + + +def _multiclass_hinge_loss_arg_validation( + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, +) -> None: + _binary_hinge_loss_arg_validation(squared, ignore_index) + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + allowed_mm = ("true", "pred", "all", "none", None) + if multiclass_mode not in allowed_mm: + raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.") + + +def multiclass_hinge_loss( + preds: Tensor, + target: Tensor, + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, + validate_args: bool = False +) -> Tensor: + if validate_args: + _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) + _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index) + measures, total = _multiclass_hinge_loss_update() + return _hinge_loss_compute(measures, total) + + + + +# -------------------------- Old stuff -------------------------- + class MulticlassMode(EnumStr): """Enum to represent possible multiclass modes of hinge. @@ -29,7 +139,7 @@ class MulticlassMode(EnumStr): """ CRAMMER_SINGER = "crammer-singer" - ONE_VS_ALL = "one-vs-all" + ONE_VS_ALL = def _check_shape_and_type_consistency_hinge( From 00a6c1adbd1e84abb4245ac8a4545372b6d4fc08 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 1 Aug 2022 15:06:54 +0200 Subject: [PATCH 02/17] some code --- .../classification/calibration_error.py | 103 +++++++++++++++++- .../classification/confusion_matrix.py | 13 ++- 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 5f08cf73400..6b00da464a0 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,14 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
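# Illustrative sketch (not part of the patch) of what the calibration-error helpers in
# this file compute: top-class confidences are grouped into equal-width bins, and the
# gap between mean confidence and mean accuracy per bin is aggregated -- a weighted sum
# for the "l1" (ECE) norm, the largest gap for the "max" (MCE) norm.  `_ece_sketch` is
# a hypothetical name used only for this illustration.
import torch

def _ece_sketch(confidences: torch.Tensor, accuracies: torch.Tensor, n_bins: int = 15) -> torch.Tensor:
    bins = torch.linspace(0, 1, n_bins + 1)          # same equal-width binning as `_ce_compute`
    ece = torch.zeros(())
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        prop = in_bin.float().mean()                 # fraction of samples falling in this bin
        if prop > 0:
            gap = (accuracies[in_bin].float().mean() - confidences[in_bin].mean()).abs()
            ece = ece + gap * prop                   # "l1" norm; "max" would track the largest gap instead
    return ece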
-from typing import Tuple +from typing import Tuple, Optional import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_format, +) def _binning_with_loop( @@ -82,7 +89,7 @@ def _binning_bucketize( def _ce_compute( confidences: Tensor, accuracies: Tensor, - bin_boundaries: Tensor, + bin_boundaries: Union[Tensor, int], norm: str = "l1", debias: bool = False, ) -> Tensor: @@ -102,6 +109,9 @@ def _ce_compute( Returns: Tensor: Calibration error scalar. """ + if isinstance(bin_boundaries, int): + bin_boundaries = torch.linspace(0, 1, bin_boundaries + 1, dtype=torch.float, device=confidences.device) + if norm not in {"l1", "l2", "max"}: raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. ") @@ -126,6 +136,95 @@ def _ce_compute( return ce +def _binary_calibration_error_arg_validation( + norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, +) -> None: + allowed_norm = ("l1", "l2", "max") + if norm not in allowed_norm: + raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_calibration_error_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tensor: + confidences, accuracies = preds, target == preds.round().int() + return confidences, accuracies + + +def binary_calibration_error( + preds: Tensor, + target: Tensor, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _binary_calibration_error_arg_validation(norm, ignore_index) + _binary_calibration_error_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold=0.0, ignore_index=ignore_index, should_threshold=False) + confidences, accuracies = _binary_calibration_error_update(preds, target) + return _ce_compute(confidences, accuracies, n_bins, norm) + + +def _multiclass_calibration_error_arg_validation( + num_classes: int, norm: Literal["l1", "l2", "max"] = "l1",, ignore_index: Optional[int] = None, +) -> None: + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + allowed_norm = ("l1", "l2", "max") + if norm not in allowed_norm: + raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an 
integer, but got {ignore_index}") + + +def _multiclass_calibration_error_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _multiclass_confusion_matrix_update( + preds: Tensor, target: Tensor, num_classes: int +) -> Tensor: + +def multiclass_calibration_error( + preds: Tensor, + target: Tensor, + num_classes: int, + n_bins: int = 15, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multiclass_calibration_error_arg_validation(num_classes, norm, ignore_index) + _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, should_threshold=False) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _multiclass_confusion_matrix_compute(confmat, normalize) + +# -------------------------- Old stuff -------------------------- + + + def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: """Given a predictions and targets tensor, computes the confidences of the top-1 prediction and records their correctness. diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 9fdc0460475..7e95cdbbc41 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -116,6 +116,7 @@ def _binary_confusion_matrix_format( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, + should_threshold: bool = True, ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -134,7 +135,8 @@ def _binary_confusion_matrix_format( if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid preds = preds.sigmoid() - preds = preds > threshold + if should_threshold: + preds = preds > threshold return preds, target @@ -296,7 +298,7 @@ def _multiclass_confusion_matrix_tensor_validation( def _multiclass_confusion_matrix_format( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = None + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, should_threshold: bool = True, ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. 
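# Illustration (not from the patch) of what the new `should_threshold` flag preserves:
# with the default `True` the formatting helpers binarize/argmax predictions as before,
# while `False` keeps the probability (or logit) scores so the calibration-error code
# can use them as confidences.  For multiclass scores of shape (N, C, ...) the hunk
# below achieves this by moving the class axis last and flattening the rest (it uses a
# `_movedim` helper; plain `torch.movedim` is shown here for clarity):
#
#     preds = torch.movedim(preds, 1, -1).reshape(-1, preds.shape[1])  # (N, C, ...) -> (N * ..., C)
#     target = target.flatten()                                        # (N, ...)    -> (N * ...,)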
@@ -304,10 +306,13 @@ def _multiclass_confusion_matrix_format( - Remove all datapoints that should be ignored """ # Apply argmax if we have one more dimension - if preds.ndim == target.ndim + 1: + if preds.ndim == target.ndim + 1 and should_threshold: preds = preds.argmax(dim=1) - preds = preds.flatten() + if should_threshold: + preds = preds.flatten() + else: + preds = _movedim(preds, 1, -1).reshape(-1, preds.shape[1]) target = target.flatten() if ignore_index is not None: From 89cdb286f0c3d1852e176c08decf43062c3d5c0d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 3 Aug 2022 11:15:51 +0200 Subject: [PATCH 03/17] updates --- .../classification/calibration_error.py | 12 +- .../functional/classification/hinge.py | 2 +- .../classification/test_calibration_error.py | 339 ++++++++++++++---- 3 files changed, 269 insertions(+), 84 deletions(-) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 6b00da464a0..30a0ab59609 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import torch from torch import Tensor @@ -158,7 +158,7 @@ def _binary_calibration_error_tensor_validation( def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tensor: - confidences, accuracies = preds, target == preds.round().int() + confidences, accuracies = preds, (target == preds.round().int()).float() return confidences, accuracies @@ -179,7 +179,7 @@ def binary_calibration_error( def _multiclass_calibration_error_arg_validation( - num_classes: int, norm: Literal["l1", "l2", "max"] = "l1",, ignore_index: Optional[int] = None, + num_classes: int, norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, ) -> None: if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") @@ -201,9 +201,11 @@ def _multiclass_calibration_error_tensor_validation( ) -def _multiclass_confusion_matrix_update( +def _multiclass_calibration_error_update( preds: Tensor, target: Tensor, num_classes: int ) -> Tensor: + pass + def multiclass_calibration_error( preds: Tensor, @@ -219,7 +221,7 @@ def multiclass_calibration_error( _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, should_threshold=False) confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) - return _multiclass_confusion_matrix_compute(confmat, normalize) + return _ce_compute(confidences, accuracies, n_bins, norm) # -------------------------- Old stuff -------------------------- diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 5c117d2d5ec..36f30c8b038 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -139,7 +139,7 @@ class MulticlassMode(EnumStr): """ CRAMMER_SINGER = "crammer-singer" - ONE_VS_ALL = + ONE_VS_ALL = "one-vs-all" def _check_shape_and_type_consistency_hinge( diff --git 
a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 52263f2ce11..8a30eab4b7b 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -1,29 +1,40 @@ -import functools +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial import re import numpy as np import pytest from scipy.special import softmax as _softmax -from torchmetrics import CalibrationError -from torchmetrics.functional import calibration_error +from torchmetrics.functional.classification.calibration_error import ( + binary_calibration_error, + multiclass_calibration_error +) from torchmetrics.utilities.checks import _input_format_classification from torchmetrics.utilities.enums import DataType -from unittests.classification.inputs import _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all # TODO: replace this with official sklearn implementation after next sklearn release from unittests.helpers.reference_metrics import _calibration_error as sk_calib -from unittests.helpers.testers import THRESHOLD, MetricTester +from unittests.helpers.testers import THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_calibration(preds, target, n_bins, norm, debias=False): +def _sk_calibration(preds, target, n_bins, norm, debias=False, ignore_index=None): _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD) sk_preds, sk_target = preds.numpy(), target.numpy() if mode == DataType.BINARY: @@ -46,77 +57,249 @@ def _sk_calibration(preds, target, n_bins, norm, debias=False): return sk_calib(y_true=sk_target, y_prob=sk_preds, norm=norm, n_bins=n_bins, reduce_bias=debias) -@pytest.mark.parametrize("n_bins", [10, 15, 20]) -@pytest.mark.parametrize("norm", ["l1", "l2", "max"]) -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary_logits.preds, _input_binary_logits.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mcls_logits.preds, _input_mcls_logits.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], -) -class TestCE(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): - self.run_class_metric_test( - ddp=ddp, + +@pytest.mark.parametrize("input", (_binary_cases[1], 
_binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryCalibrationError(MetricTester): + # @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + # @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + # @pytest.mark.parametrize("ddp", [True, False]) + # def test_binary_calibration_error(self, input, ddp, normalize, ignore_index): + # preds, target = input + # if ignore_index is not None: + # target = inject_ignore_index(target, ignore_index) + # self.run_class_metric_test( + # ddp=ddp, + # preds=preds, + # target=target, + # metric_class=BinaryCalibrationError, + # sk_metric=partial(_sk_calibration_error_binary, normalize=normalize, ignore_index=ignore_index), + # metric_args={ + # "threshold": THRESHOLD, + # "normalize": normalize, + # "ignore_index": ignore_index, + # }, + # ) + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "l2", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( preds=preds, target=target, - metric_class=CalibrationError, - sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), - dist_sync_on_step=dist_sync_on_step, - metric_args={"n_bins": n_bins, "norm": norm}, + metric_functional=binary_calibration_error, + sk_metric=partial(_sk_calibration, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, ) - def test_ce_functional(self, preds, target, n_bins, norm): - self.run_functional_metric_test( - preds, - target, - metric_functional=calibration_error, - sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), - metric_args={"n_bins": n_bins, "norm": norm}, - ) +# def test_binary_calibration_error_differentiability(self, input): +# preds, target = input +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=BinaryCalibrationError, +# metric_functional=binary_calibration_error, +# metric_args={"threshold": THRESHOLD}, +# ) +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_binary_calibration_error_dtype_cpu(self, input, dtype): +# preds, target = input +# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: +# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") +# if (preds < 0).any() and dtype == torch.half: +# pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") +# self.run_precision_test_cpu( +# preds=preds, +# target=target, +# metric_module=BinaryCalibrationError, +# metric_functional=binary_calibration_error, +# metric_args={"threshold": THRESHOLD}, +# dtype=dtype, +# ) -@pytest.mark.parametrize("preds, targets", [(_input_mlb_prob.preds, _input_mlb_prob.target)]) -def test_invalid_input(preds, targets): - for p, t in zip(preds, targets): - with pytest.raises( - ValueError, - match=re.escape( - f"Calibration error is not well-defined for data with size {p.size()} and targets {t.size()}." 
- ), - ): - calibration_error(p, t) - - -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], -) -def test_invalid_norm(preds, target): - with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. "): - calibration_error(preds, target, norm="l3") - - -@pytest.mark.parametrize("n_bins", [-10, -1, "fsd"]) -@pytest.mark.parametrize( - "preds, targets", - [ - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - ], -) -def test_invalid_bins(preds, targets, n_bins): - for p, t in zip(preds, targets): - with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}"): - calibration_error(p, t, n_bins=n_bins) +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_binary_calibration_error_dtype_gpu(self, input, dtype): +# preds, target = input +# self.run_precision_test_gpu( +# preds=preds, +# target=target, +# metric_module=BinaryCalibrationError, +# metric_functional=binary_calibration_error, +# metric_args={"threshold": THRESHOLD}, +# dtype=dtype, +# ) + + +# def _sk_calibration_error_multiclass(preds, target, normalize=None, ignore_index=None): +# preds = preds.numpy() +# target = target.numpy() +# if np.issubdtype(preds.dtype, np.floating): +# preds = np.argmax(preds, axis=1) +# preds = preds.flatten() +# target = target.flatten() +# target, preds = remove_ignore_index(target, preds, ignore_index) +# return sk_calibration_error(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) + + +# @pytest.mark.parametrize("input", _multiclass_cases) +# class TestMulticlassCalibrationError(MetricTester): +# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# @pytest.mark.parametrize("ddp", [True, False]) +# def test_multiclass_calibration_error(self, input, ddp, normalize, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=MulticlassCalibrationError, +# sk_metric=partial(_sk_calibration_error_multiclass, normalize=normalize, ignore_index=ignore_index), +# metric_args={ +# "num_classes": NUM_CLASSES, +# "normalize": normalize, +# "ignore_index": ignore_index, +# }, +# ) + +# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# def test_multiclass_calibration_error_functional(self, input, normalize, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=multiclass_calibration_error, +# sk_metric=partial(_sk_calibration_error_multiclass, normalize=normalize, ignore_index=ignore_index), +# metric_args={ +# "num_classes": NUM_CLASSES, +# "normalize": normalize, +# "ignore_index": ignore_index, +# }, +# ) + +# def test_multiclass_calibration_error_differentiability(self, input): +# preds, target = input +# self.run_differentiability_test( +# 
preds=preds, +# target=target, +# metric_module=MulticlassCalibrationError, +# metric_functional=multiclass_calibration_error, +# metric_args={"num_classes": NUM_CLASSES}, +# ) + +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_multiclass_calibration_error_dtype_cpu(self, input, dtype): +# preds, target = input +# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: +# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") +# self.run_precision_test_cpu( +# preds=preds, +# target=target, +# metric_module=MulticlassCalibrationError, +# metric_functional=multiclass_calibration_error, +# metric_args={"num_classes": NUM_CLASSES}, +# dtype=dtype, +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_multiclass_calibration_error_dtype_gpu(self, input, dtype): +# preds, target = input +# self.run_precision_test_gpu( +# preds=preds, +# target=target, +# metric_module=MulticlassCalibrationError, +# metric_functional=multiclass_calibration_error, +# metric_args={"num_classes": NUM_CLASSES}, +# dtype=dtype, +# ) + + +# -------------------------- Old stuff -------------------------- + +# @pytest.mark.parametrize("n_bins", [10, 15, 20]) +# @pytest.mark.parametrize("norm", ["l1", "l2", "max"]) +# @pytest.mark.parametrize( +# "preds, target", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target), +# (_input_binary_logits.preds, _input_binary_logits.target), +# (_input_mcls_prob.preds, _input_mcls_prob.target), +# (_input_mcls_logits.preds, _input_mcls_logits.target), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target), +# ], +# ) +# class TestCE(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_ce(self, preds, target, n_bins, ddp, dist_sync_on_step, norm): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=CalibrationError, +# sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"n_bins": n_bins, "norm": norm}, +# ) + +# def test_ce_functional(self, preds, target, n_bins, norm): +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=calibration_error, +# sk_metric=functools.partial(_sk_calibration, n_bins=n_bins, norm=norm), +# metric_args={"n_bins": n_bins, "norm": norm}, +# ) + + +# @pytest.mark.parametrize("preds, targets", [(_input_mlb_prob.preds, _input_mlb_prob.target)]) +# def test_invalid_input(preds, targets): +# for p, t in zip(preds, targets): +# with pytest.raises( +# ValueError, +# match=re.escape( +# f"Calibration error is not well-defined for data with size {p.size()} and targets {t.size()}." +# ), +# ): +# calibration_error(p, t) + + +# @pytest.mark.parametrize( +# "preds, target", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target), +# (_input_mcls_prob.preds, _input_mcls_prob.target), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target), +# ], +# ) +# def test_invalid_norm(preds, target): +# with pytest.raises(ValueError, match="Norm l3 is not supported. Please select from l1, l2, or max. 
"): +# calibration_error(preds, target, norm="l3") + + +# @pytest.mark.parametrize("n_bins", [-10, -1, "fsd"]) +# @pytest.mark.parametrize( +# "preds, targets", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target), +# (_input_mcls_prob.preds, _input_mcls_prob.target), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target), +# ], +# ) +# def test_invalid_bins(preds, targets, n_bins): +# for p, t in zip(preds, targets): +# with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}"): +# calibration_error(p, t, n_bins=n_bins) From 0c61683f3ea73bff7b76baff510697ff5474696f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 3 Aug 2022 16:16:11 +0200 Subject: [PATCH 04/17] working calibration --- requirements/classification_test.txt | 1 + requirements/devel.txt | 1 + .../classification/calibration_error.py | 98 ++++++- .../classification/calibration_error.py | 72 +++-- .../classification/confusion_matrix.py | 5 +- .../functional/classification/hinge.py | 30 +- tests/unittests/classification/inputs.py | 159 +++++++---- .../classification/test_calibration_error.py | 265 +++++++++--------- 8 files changed, 399 insertions(+), 232 deletions(-) create mode 100644 requirements/classification_test.txt diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt new file mode 100644 index 00000000000..318026a665c --- /dev/null +++ b/requirements/classification_test.txt @@ -0,0 +1 @@ +netcal # calibration_error diff --git a/requirements/devel.txt b/requirements/devel.txt index 757c79a82ae..05aa5e61b9a 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -15,3 +15,4 @@ -r text_test.txt -r audio_test.txt -r detection_test.txt +-r classification_test.txt diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index f9b7aa23dde..a9526fd5806 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -11,16 +11,110 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, List +from typing import Any, List, Optional import torch from torch import Tensor -from torchmetrics.functional.classification.calibration_error import _ce_compute, _ce_update +from torchmetrics.functional.classification.calibration_error import ( + _binary_calibration_error_arg_validation, + _binary_calibration_error_tensor_validation, + _binary_calibration_error_update, + _binary_confusion_matrix_format, + _ce_compute, + _ce_update, + _multiclass_calibration_error_arg_validation, + _multiclass_calibration_error_tensor_validation, + _multiclass_calibration_error_update, + _multiclass_confusion_matrix_format, +) from torchmetrics.metric import Metric from torchmetrics.utilities.data import dim_zero_cat +class BinaryCalibrationError(Metric): + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + n_bins: int = 15, + norm: str = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) + self.validate_args = validate_args + self.n_bins = n_bins + self.norm = norm + self.ignore_index = ignore_index + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_calibration_error_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False + ) + confidences, accuracies = _binary_calibration_error_update(preds, target) + self.confidences.append(confidences) + self.accuracies.append(accuracies) + + def compute(self) -> Tensor: + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) + + +class MulticlassCalibrationError(Metric): + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + n_bins: int = 15, + norm: str = "l1", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) + self.validate_args = validate_args + self.num_classes = num_classes + self.n_bins = n_bins + self.norm = norm + self.ignore_index = ignore_index + self.add_state("confidences", [], dist_reduce_fx="cat") + self.add_state("accuracies", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format( + preds, target, ignore_index=self.ignore_index, should_threshold=False + ) + confidences, accuracies = _multiclass_calibration_error_update(preds, target) + self.confidences.append(confidences) + self.accuracies.append(accuracies) + + def compute(self) -> Tensor: + confidences = dim_zero_cat(self.confidences) + accuracies = dim_zero_cat(self.accuracies) + return _ce_compute(confidences, accuracies, self.n_bins, norm=self.norm) + + +# -------------------------- Old stuff 
-------------------------- + + class CalibrationError(Metric): r"""`Computes the Top-label Calibration Error`_ Three different norms are implemented, each corresponding to variations on the calibration error metric. diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 30a0ab59609..2faf01a8f24 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -11,21 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Optional, Union +from typing import Optional, Tuple, Union import torch from torch import Tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import DataType -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 from torchmetrics.functional.classification.confusion_matrix import ( - _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_format, - _multiclass_confusion_matrix_tensor_validation, + _binary_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, ) +from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_8 def _binning_with_loop( @@ -68,6 +68,7 @@ def _binning_bucketize( Returns: tuple with binned accuracy, binned confidence and binned probabilities """ + accuracies = accuracies.to(dtype=confidences.dtype) acc_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) conf_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) count_bin = torch.zeros(len(bin_boundaries) - 1, device=confidences.device, dtype=confidences.dtype) @@ -115,10 +116,11 @@ def _ce_compute( if norm not in {"l1", "l2", "max"}: raise ValueError(f"Norm {norm} is not supported. Please select from l1, l2, or max. 
") - if _TORCH_GREATER_EQUAL_1_8: - acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) - else: - acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries) + with torch.no_grad(): + if _TORCH_GREATER_EQUAL_1_8: + acc_bin, conf_bin, prop_bin = _binning_bucketize(confidences, accuracies, bin_boundaries) + else: + acc_bin, conf_bin, prop_bin = _binning_with_loop(confidences, accuracies, bin_boundaries) if norm == "l1": ce = torch.sum(torch.abs(acc_bin - conf_bin) * prop_bin) @@ -137,8 +139,12 @@ def _ce_compute( def _binary_calibration_error_arg_validation( - norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, + n_bins: int, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, ) -> None: + if not isinstance(n_bins, int) or n_bins < 1: + raise ValueError(f"Expected argument `n_bins` to be an integer larger than 0, but got {n_bins}") allowed_norm = ("l1", "l2", "max") if norm not in allowed_norm: raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") @@ -158,31 +164,38 @@ def _binary_calibration_error_tensor_validation( def _binary_calibration_error_update(preds: Tensor, target: Tensor) -> Tensor: - confidences, accuracies = preds, (target == preds.round().int()).float() + confidences, accuracies = preds, target return confidences, accuracies def binary_calibration_error( - preds: Tensor, - target: Tensor, - n_bins: int = 15, + preds: Tensor, + target: Tensor, + n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: if validate_args: - _binary_calibration_error_arg_validation(norm, ignore_index) + _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) _binary_calibration_error_tensor_validation(preds, target, ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, threshold=0.0, ignore_index=ignore_index, should_threshold=False) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=ignore_index, should_threshold=False + ) confidences, accuracies = _binary_calibration_error_update(preds, target) return _ce_compute(confidences, accuracies, n_bins, norm) def _multiclass_calibration_error_arg_validation( - num_classes: int, norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, + num_classes: int, + n_bins: int, + norm: Literal["l1", "l2", "max"] = "l1", + ignore_index: Optional[int] = None, ) -> None: if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if not isinstance(n_bins, int) or n_bins < 1: + raise ValueError(f"Expected argument `n_bins` to be an integer larger than 0, but got {n_bins}") allowed_norm = ("l1", "l2", "max") if norm not in allowed_norm: raise ValueError(f"Expected argument `norm` to be one of {allowed_norm}, but got {norm}.") @@ -202,29 +215,34 @@ def _multiclass_calibration_error_tensor_validation( def _multiclass_calibration_error_update( - preds: Tensor, target: Tensor, num_classes: int + preds: Tensor, + target: Tensor, ) -> Tensor: - pass + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) + confidences, predictions = preds.max(dim=1) + accuracies = predictions.eq(target) + return confidences, accuracies def multiclass_calibration_error( - preds: Tensor, - target: Tensor, - num_classes: int, - n_bins: 
int = 15, + preds: Tensor, + target: Tensor, + num_classes: int, + n_bins: int = 15, norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: if validate_args: - _multiclass_calibration_error_arg_validation(num_classes, norm, ignore_index) + _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, should_threshold=False) - confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + confidences, accuracies = _multiclass_calibration_error_update(preds, target) return _ce_compute(confidences, accuracies, n_bins, norm) -# -------------------------- Old stuff -------------------------- +# -------------------------- Old stuff -------------------------- def _ce_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]: diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 7e95cdbbc41..b3c144c5f8d 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -298,7 +298,10 @@ def _multiclass_confusion_matrix_tensor_validation( def _multiclass_confusion_matrix_format( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, should_threshold: bool = True, + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + should_threshold: bool = True, ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 36f30c8b038..a5f86708e28 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -11,18 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from multiprocessing.sharedctypes import Value from typing import Optional, Tuple, Union -from attr import validate import torch from torch import Tensor, tensor from typing_extensions import Literal -from torchmetrics.utilities.checks import _input_squeeze +from torchmetrics.utilities.checks import _check_same_shape, _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr -from torchmetrics.utilities.checks import _check_same_shape def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: @@ -36,9 +33,7 @@ def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") -def _binary_hinge_loss_tensor_validation( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = None -) -> None: +def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None: # Check that they have same shape _check_same_shape(preds, target) @@ -53,14 +48,14 @@ def _binary_hinge_loss_tensor_validation( f"Detected the following values in `target`: {unique_values} but expected only" f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." 
) - + if not preds.is_floating_point(): raise ValueError( "Expected argument `preds` to be an floating tensor with probability/logit scores," f" but got tensor with dtype {preds.dtype}" ) - + def _binary_hinge_loss_update( preds: Tensor, target: Tensor, @@ -73,7 +68,6 @@ def _binary_hinge_loss_update( margin[target] = preds[target] margin[~target] = -preds[~target] - measures = 1 - margin measures = torch.clamp(measures, 0) @@ -83,12 +77,13 @@ def _binary_hinge_loss_update( total = tensor(target.shape[0], device=target.device) return measures.sum(dim=0), total + def binary_hinge_loss( preds: Tensor, target: Tensor, squared: bool = False, ignore_index: Optional[int] = None, - validate_args: bool = False + validate_args: bool = False, ) -> Tensor: if validate_args: _binary_hinge_loss_arg_validation(squared, ignore_index) @@ -111,6 +106,14 @@ def _multiclass_hinge_loss_arg_validation( raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.") +def _multiclass_hinge_loss_tensor_validation(): + pass + + +def _multiclass_hinge_loss_update(): + pass + + def multiclass_hinge_loss( preds: Tensor, target: Tensor, @@ -118,7 +121,7 @@ def multiclass_hinge_loss( squared: bool = False, multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", ignore_index: Optional[int] = None, - validate_args: bool = False + validate_args: bool = False, ) -> Tensor: if validate_args: _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) @@ -127,10 +130,9 @@ def multiclass_hinge_loss( return _hinge_loss_compute(measures, total) - - # -------------------------- Old stuff -------------------------- + class MulticlassMode(EnumStr): """Enum to represent possible multiclass modes of hinge. diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index bb3762a0edd..a17d525287a 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -13,6 +13,7 @@ # limitations under the License. 
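# The change to the shared input fixtures below wraps every `Input` case in
# `pytest.param(..., id="input[...]")`, which gives parametrized tests readable ids
# while still letting individual cases be selected by index.  Usage as in the new
# calibration tests (illustrative only):
#
#     @pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2]))
#     def test_metric(input):
#         preds, target = input   # pytest unwraps the ParameterSet back to the Input namedtuple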
from collections import namedtuple +import pytest import torch from torch import Tensor @@ -71,82 +72,136 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: ) _binary_cases = ( - Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-labels]", ), - Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), - Input( - preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + pytest.param( + Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), + id="input[single dim-probs]", ), - Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-logits]", ), - Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-labels]", ), - Input( - preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-probs]", + ), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-logits]", ), ) _multiclass_cases = ( - Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + pytest.param( + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-labels]", ), - Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + pytest.param( + Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-probs]", ), - Input( - preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + pytest.param( + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + id="input[single dim-logits]", ), - Input( - preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + 
preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-labels]", ), - Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-probs]", ), - Input( - preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + pytest.param( + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + id="input[multi dim-logits]", ), ) _multilabel_cases = ( - Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-labels]", ), - Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-probs]", ), - Input( - preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + id="input[single dim-logits]", ), - Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + pytest.param( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-labels]", ), - Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + pytest.param( + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-probs]", ), - Input( - preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + pytest.param( + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + id="input[multi dim-logits]", ), ) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 8a30eab4b7b..380c3c244d6 100644 --- 
a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -12,75 +12,62 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -import re import numpy as np import pytest -from scipy.special import softmax as _softmax +import torch +from netcal.metrics import ECE, MCE +from scipy.special import expit as sigmoid +from scipy.special import softmax +from torchmetrics.classification.calibration_error import BinaryCalibrationError, MulticlassCalibrationError from torchmetrics.functional.classification.calibration_error import ( binary_calibration_error, - multiclass_calibration_error + multiclass_calibration_error, ) -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all - -# TODO: replace this with official sklearn implementation after next sklearn release -from unittests.helpers.reference_metrics import _calibration_error as sk_calib -from unittests.helpers.testers import THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_calibration(preds, target, n_bins, norm, debias=False, ignore_index=None): - _, _, mode = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = preds.numpy(), target.numpy() - if mode == DataType.BINARY: - if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): - sk_preds = 1.0 / (1 + np.exp(-sk_preds)) # sigmoid transform - if mode == DataType.MULTICLASS: - if not np.logical_and(0 <= sk_preds, sk_preds <= 1).all(): - sk_preds = _softmax(sk_preds, axis=1) - # binary label is whether or not the predicted class is correct - sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target) - sk_preds = np.max(sk_preds, axis=1) - elif mode == DataType.MULTIDIM_MULTICLASS: - # reshape from shape (N, C, ...) to (N*EXTRA_DIMS, C) - sk_preds = np.transpose(sk_preds, axes=(0, 2, 1)) - sk_preds = sk_preds.reshape(np.prod(sk_preds.shape[:-1]), sk_preds.shape[-1]) - # reshape from shape (N, ...) 
to (N*EXTRA_DIMS,) - # binary label is whether or not the predicted class is correct - sk_target = np.equal(np.argmax(sk_preds, axis=1), sk_target.flatten()) - sk_preds = np.max(sk_preds, axis=1) - return sk_calib(y_true=sk_target, y_prob=sk_preds, norm=norm, n_bins=n_bins, reduce_bias=debias) - +def _sk_binary_calibration_error(preds, target, n_bins, norm, ignore_index): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + metric = ECE if norm == "l1" else MCE + return metric(n_bins).measure(preds, target) @pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) class TestBinaryCalibrationError(MetricTester): - # @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) - # @pytest.mark.parametrize("ignore_index", [None, -1, 0]) - # @pytest.mark.parametrize("ddp", [True, False]) - # def test_binary_calibration_error(self, input, ddp, normalize, ignore_index): - # preds, target = input - # if ignore_index is not None: - # target = inject_ignore_index(target, ignore_index) - # self.run_class_metric_test( - # ddp=ddp, - # preds=preds, - # target=target, - # metric_class=BinaryCalibrationError, - # sk_metric=partial(_sk_calibration_error_binary, normalize=normalize, ignore_index=ignore_index), - # metric_args={ - # "threshold": THRESHOLD, - # "normalize": normalize, - # "ignore_index": ignore_index, - # }, - # ) @pytest.mark.parametrize("n_bins", [10, 15, 20]) - @pytest.mark.parametrize("norm", ["l1", "l2", "max"]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_calibration_error(self, input, ddp, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryCalibrationError, + sk_metric=partial(_sk_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_index): preds, target = input @@ -90,7 +77,7 @@ def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_i preds=preds, target=target, metric_functional=binary_calibration_error, - sk_metric=partial(_sk_calibration, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + sk_metric=partial(_sk_binary_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), metric_args={ "n_bins": n_bins, "norm": norm, @@ -98,96 +85,100 @@ def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_i }, ) -# def test_binary_calibration_error_differentiability(self, input): -# preds, target = input -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=BinaryCalibrationError, -# metric_functional=binary_calibration_error, -# metric_args={"threshold": THRESHOLD}, -# ) + def test_binary_calibration_error_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, 
+ metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, + ) -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_binary_calibration_error_dtype_cpu(self, input, dtype): -# preds, target = input -# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: -# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") -# if (preds < 0).any() and dtype == torch.half: -# pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") -# self.run_precision_test_cpu( -# preds=preds, -# target=target, -# metric_module=BinaryCalibrationError, -# metric_functional=binary_calibration_error, -# metric_args={"threshold": THRESHOLD}, -# dtype=dtype, -# ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_calibration_error_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, + dtype=dtype, + ) -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_binary_calibration_error_dtype_gpu(self, input, dtype): -# preds, target = input -# self.run_precision_test_gpu( -# preds=preds, -# target=target, -# metric_module=BinaryCalibrationError, -# metric_functional=binary_calibration_error, -# metric_args={"threshold": THRESHOLD}, -# dtype=dtype, -# ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_calibration_error_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryCalibrationError, + metric_functional=binary_calibration_error, + dtype=dtype, + ) -# def _sk_calibration_error_multiclass(preds, target, normalize=None, ignore_index=None): -# preds = preds.numpy() -# target = target.numpy() -# if np.issubdtype(preds.dtype, np.floating): -# preds = np.argmax(preds, axis=1) -# preds = preds.flatten() -# target = target.flatten() -# target, preds = remove_ignore_index(target, preds, ignore_index) -# return sk_calibration_error(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) +def _sk_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index): + preds = preds.numpy() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target, preds = remove_ignore_index(target, preds, ignore_index) + metric = ECE if norm == "l1" else MCE + return metric(n_bins).measure(preds, target) -# @pytest.mark.parametrize("input", _multiclass_cases) -# class TestMulticlassCalibrationError(MetricTester): -# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# @pytest.mark.parametrize("ddp", [True, False]) -# def test_multiclass_calibration_error(self, input, ddp, normalize, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = 
inject_ignore_index(target, ignore_index) -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=MulticlassCalibrationError, -# sk_metric=partial(_sk_calibration_error_multiclass, normalize=normalize, ignore_index=ignore_index), -# metric_args={ -# "num_classes": NUM_CLASSES, -# "normalize": normalize, -# "ignore_index": ignore_index, -# }, -# ) +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassCalibrationError(MetricTester): + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_calibration_error(self, input, ddp, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassCalibrationError, + sk_metric=partial(_sk_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("norm", ["l1", "max"]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_calibration_error_functional(self, input, n_bins, norm, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_calibration_error, + sk_metric=partial(_sk_multiclass_calibration_error, n_bins=n_bins, norm=norm, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "n_bins": n_bins, + "norm": norm, + "ignore_index": ignore_index, + }, + ) -# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# def test_multiclass_calibration_error_functional(self, input, normalize, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = inject_ignore_index(target, ignore_index) -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=multiclass_calibration_error, -# sk_metric=partial(_sk_calibration_error_multiclass, normalize=normalize, ignore_index=ignore_index), -# metric_args={ -# "num_classes": NUM_CLASSES, -# "normalize": normalize, -# "ignore_index": ignore_index, -# }, -# ) # def test_multiclass_calibration_error_differentiability(self, input): # preds, target = input @@ -301,5 +292,7 @@ def test_binary_calibration_error_functional(self, input, n_bins, norm, ignore_i # ) # def test_invalid_bins(preds, targets, n_bins): # for p, t in zip(preds, targets): -# with pytest.raises(ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}"): +# with pytest.raises( +# ValueError, match=f"Expected argument `n_bins` to be a int larger than 0 but got {n_bins}" +# ): # calibration_error(p, t, n_bins=n_bins) From 3deb07c9f7d9713d5f6593a1c84e26146f9ee619 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 3 Aug 2022 19:22:27 +0200 Subject: [PATCH 05/17] improve tests --- .../classification/test_calibration_error.py | 71 ++++++++++--------- 1 file changed, 36 insertions(+), 35 
deletions(-) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 380c3c244d6..259597d7575 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -179,43 +179,44 @@ def test_multiclass_calibration_error_functional(self, input, n_bins, norm, igno }, ) + def test_multiclass_calibration_error_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + ) -# def test_multiclass_calibration_error_differentiability(self, input): -# preds, target = input -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=MulticlassCalibrationError, -# metric_functional=multiclass_calibration_error, -# metric_args={"num_classes": NUM_CLASSES}, -# ) - -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_multiclass_calibration_error_dtype_cpu(self, input, dtype): -# preds, target = input -# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: -# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") -# self.run_precision_test_cpu( -# preds=preds, -# target=target, -# metric_module=MulticlassCalibrationError, -# metric_functional=multiclass_calibration_error, -# metric_args={"num_classes": NUM_CLASSES}, -# dtype=dtype, -# ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_calibration_error_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.softmax in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_multiclass_calibration_error_dtype_gpu(self, input, dtype): -# preds, target = input -# self.run_precision_test_gpu( -# preds=preds, -# target=target, -# metric_module=MulticlassCalibrationError, -# metric_functional=multiclass_calibration_error, -# metric_args={"num_classes": NUM_CLASSES}, -# dtype=dtype, -# ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_calibration_error_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassCalibrationError, + metric_functional=multiclass_calibration_error, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) # -------------------------- Old stuff -------------------------- From 2e42c5069b76eceb52e62681af6efc016bf44660 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 3 Aug 2022 19:49:21 +0200 Subject: [PATCH 06/17] docs --- .../classification/calibration_error.rst | 24 +++++++++++++++++++ docs/source/classification/hinge_loss.rst | 24 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff 
--git a/docs/source/classification/calibration_error.rst b/docs/source/classification/calibration_error.rst index b10cb834f5f..e6a9c0a9427 100644 --- a/docs/source/classification/calibration_error.rst +++ b/docs/source/classification/calibration_error.rst @@ -15,8 +15,32 @@ ________________ .. autoclass:: torchmetrics.CalibrationError :noindex: +BinaryCalibrationError +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryCalibrationError + :noindex: + +MulticlassCalibrationError +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassCalibrationError + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.calibration_error :noindex: + +binary_calibration_error +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_calibration_error + :noindex: + +multiclass_calibration_error +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_calibration_error + :noindex: diff --git a/docs/source/classification/hinge_loss.rst b/docs/source/classification/hinge_loss.rst index 4aa5562ab6a..4de979946b8 100644 --- a/docs/source/classification/hinge_loss.rst +++ b/docs/source/classification/hinge_loss.rst @@ -13,8 +13,32 @@ ________________ .. autoclass:: torchmetrics.HingeLoss :noindex: +BinaryHingeLoss +^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryHingeLoss + :noindex: + +MulticlassHingeLoss +^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassHingeLoss + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.hinge_loss :noindex: + +binary_hinge_loss +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_hinge_loss + :noindex: + +multiclass_hinge_loss +^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_hinge_loss + :noindex: From bed959db48d834ae734c1afca5f91d37c3735b96 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 3 Aug 2022 19:50:12 +0200 Subject: [PATCH 07/17] update --- .../functional/classification/__init__.py | 12 +- tests/unittests/classification/test_hinge.py | 441 +++++++++++++----- 2 files changed, 324 insertions(+), 129 deletions(-) diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 05ab3b92ca9..306c11d9fe3 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -15,7 +15,11 @@ from torchmetrics.functional.classification.auc import auc # noqa: F401 from torchmetrics.functional.classification.auroc import auroc # noqa: F401 from torchmetrics.functional.classification.average_precision import average_precision # noqa: F401 -from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 +from torchmetrics.functional.classification.calibration_error import ( # noqa: F401 + calibration_error, + binary_calibration_error, + multiclass_calibration_error, +) from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 binary_confusion_matrix, @@ -40,7 +44,11 @@ multiclass_hamming_distance, multilabel_hamming_distance, ) -from torchmetrics.functional.classification.hinge import hinge_loss # noqa: F401 +from torchmetrics.functional.classification.hinge import ( # noqa: F401 + hinge_loss, + binary_hinge_loss, + multiclass_hinge_loss, +) from torchmetrics.functional.classification.jaccard 
import jaccard_index # noqa: F401 from torchmetrics.functional.classification.kl_divergence import kl_divergence # noqa: F401 from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index f495c14d36d..3c1dcd4b473 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -27,130 +27,317 @@ torch.manual_seed(42) -_input_binary = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -) - -_input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) - -_input_multiclass = Input( - preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), - target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), -) - - -def _sk_hinge(preds, target, squared, multiclass_mode): - sk_preds, sk_target = preds.numpy(), target.numpy() - - if multiclass_mode == MulticlassMode.ONE_VS_ALL: - enc = OneHotEncoder() - enc.fit(sk_target.reshape(-1, 1)) - sk_target = enc.transform(sk_target.reshape(-1, 1)).toarray() - - if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: - sk_target = 2 * sk_target - 1 - - if squared or sk_target.max() != 1 or sk_target.min() != -1: - # Squared not an option in sklearn and infers classes incorrectly with single element, so adapted from source - if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: - margin = sk_target * sk_preds - else: - mask = np.ones_like(sk_preds, dtype=bool) - mask[np.arange(sk_target.shape[0]), sk_target] = False - margin = sk_preds[~mask] - margin -= np.max(sk_preds[mask].reshape(sk_target.shape[0], -1), axis=1) - measures = 1 - margin - measures = np.clip(measures, 0, None) - - if squared: - measures = measures**2 - return measures.mean(axis=0) - if multiclass_mode == MulticlassMode.ONE_VS_ALL: - result = np.zeros(sk_preds.shape[1]) - for i in range(result.shape[0]): - result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i]) - return result - - return sk_hinge(y_true=sk_target, pred_decision=sk_preds) - - -@pytest.mark.parametrize( - "preds, target, squared, multiclass_mode", - [ - (_input_binary.preds, _input_binary.target, False, None), - (_input_binary.preds, _input_binary.target, True, None), - (_input_binary_single.preds, _input_binary_single.target, False, None), - (_input_binary_single.preds, _input_binary_single.target, True, None), - (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.CRAMMER_SINGER), - (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.CRAMMER_SINGER), - (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.ONE_VS_ALL), - (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.ONE_VS_ALL), - ], -) -class TestHinge(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=HingeLoss, - sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "squared": squared, - "multiclass_mode": multiclass_mode, - }, - ) - - def test_hinge_fn(self, preds, target, 
squared, multiclass_mode): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), - sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), - ) - - def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=HingeLoss, - metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), - ) - - -_input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) - -_input_binary_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -) - -_input_multi_different_sizes = Input( - preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) -) - -_input_extra_dim = Input( - preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) -) - - -@pytest.mark.parametrize( - "preds, target, multiclass_mode", - [ - (_input_multi_target.preds, _input_multi_target.target, None), - (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), - (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), - (_input_extra_dim.preds, _input_extra_dim.target, None), - (_input_multiclass.preds[0], _input_multiclass.target[0], "invalid_mode"), - ], -) -def test_bad_inputs_fn(preds, target, multiclass_mode): - with pytest.raises(ValueError): - _ = hinge_loss(preds, target, multiclass_mode=multiclass_mode) - - -def test_bad_inputs_class(): - with pytest.raises(ValueError): - HingeLoss(multiclass_mode="invalid_mode") + +# def _sk_binary_hinge_loss(preds, target, n_bins, norm, ignore_index): +# preds = preds.numpy().flatten() +# target = target.numpy().flatten() +# if not ((0 < preds) & (preds < 1)).all(): +# preds = sigmoid(preds) +# target, preds = remove_ignore_index(target, preds, ignore_index) +# metric = ECE if norm == "l1" else MCE +# return metric(n_bins).measure(preds, target) + + +# @pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +# class TestBinaryHingeLoss(MetricTester): +# @pytest.mark.parametrize("n_bins", [10, 15, 20]) +# @pytest.mark.parametrize("norm", ["l1", "max"]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# @pytest.mark.parametrize("ddp", [True, False]) +# def test_binary_hinge_loss(self, input, ddp, n_bins, norm, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=BinaryHingeLoss, +# sk_metric=partial(_sk_binary_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), +# metric_args={ +# "n_bins": n_bins, +# "norm": norm, +# "ignore_index": ignore_index, +# }, +# ) + +# @pytest.mark.parametrize("n_bins", [10, 15, 20]) +# @pytest.mark.parametrize("norm", ["l1", "max"]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# def test_binary_hinge_loss_functional(self, input, n_bins, norm, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=binary_hinge_loss, +# sk_metric=partial(_sk_binary_hinge_loss, 
n_bins=n_bins, norm=norm, ignore_index=ignore_index), +# metric_args={ +# "n_bins": n_bins, +# "norm": norm, +# "ignore_index": ignore_index, +# }, +# ) + +# def test_binary_hinge_loss_differentiability(self, input): +# preds, target = input +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=BinaryHingeLoss, +# metric_functional=binary_hinge_loss, +# ) + +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_binary_hinge_loss_dtype_cpu(self, input, dtype): +# preds, target = input +# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: +# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") +# if (preds < 0).any() and dtype == torch.half: +# pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") +# self.run_precision_test_cpu( +# preds=preds, +# target=target, +# metric_module=BinaryHingeLoss, +# metric_functional=binary_hinge_loss, +# dtype=dtype, +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_binary_hinge_loss_dtype_gpu(self, input, dtype): +# preds, target = input +# self.run_precision_test_gpu( +# preds=preds, +# target=target, +# metric_module=BinaryHingeLoss, +# metric_functional=binary_hinge_loss, +# dtype=dtype, +# ) + + +# def _sk_multiclass_hinge_loss(preds, target, n_bins, norm, ignore_index): +# preds = preds.numpy() +# target = target.numpy().flatten() +# if not ((0 < preds) & (preds < 1)).all(): +# preds = softmax(preds, 1) +# preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) +# target, preds = remove_ignore_index(target, preds, ignore_index) +# metric = ECE if norm == "l1" else MCE +# return metric(n_bins).measure(preds, target) + + +# @pytest.mark.parametrize( +# "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +# ) +# class TestMulticlassHingeLoss(MetricTester): +# @pytest.mark.parametrize("n_bins", [10, 15, 20]) +# @pytest.mark.parametrize("norm", ["l1", "max"]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# @pytest.mark.parametrize("ddp", [True, False]) +# def test_multiclass_hinge_loss(self, input, ddp, n_bins, norm, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=MulticlassHingeLoss, +# sk_metric=partial(_sk_multiclass_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), +# metric_args={ +# "num_classes": NUM_CLASSES, +# "n_bins": n_bins, +# "norm": norm, +# "ignore_index": ignore_index, +# }, +# ) + +# @pytest.mark.parametrize("n_bins", [10, 15, 20]) +# @pytest.mark.parametrize("norm", ["l1", "max"]) +# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) +# def test_multiclass_hinge_loss_functional(self, input, n_bins, norm, ignore_index): +# preds, target = input +# if ignore_index is not None: +# target = inject_ignore_index(target, ignore_index) +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=multiclass_hinge_loss, +# sk_metric=partial(_sk_multiclass_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), +# metric_args={ +# "num_classes": NUM_CLASSES, +# "n_bins": n_bins, +# "norm": norm, +# "ignore_index": ignore_index, +# }, +# ) + +# def test_multiclass_hinge_loss_differentiability(self, input): +# preds, target = 
input +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=MulticlassHingeLoss, +# metric_functional=multiclass_hinge_loss, +# metric_args={"num_classes": NUM_CLASSES}, +# ) + +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_multiclass_hinge_loss_dtype_cpu(self, input, dtype): +# preds, target = input +# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: +# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") +# self.run_precision_test_cpu( +# preds=preds, +# target=target, +# metric_module=MulticlassHingeLoss, +# metric_functional=multiclass_hinge_loss, +# metric_args={"num_classes": NUM_CLASSES}, +# dtype=dtype, +# ) + +# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") +# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) +# def test_multiclass_hinge_loss_dtype_gpu(self, input, dtype): +# preds, target = input +# self.run_precision_test_gpu( +# preds=preds, +# target=target, +# metric_module=MulticlassHingeLoss, +# metric_functional=multiclass_hinge_loss, +# metric_args={"num_classes": NUM_CLASSES}, +# dtype=dtype, +# ) + + +# -------------------------- Old stuff -------------------------- + +# _input_binary = Input( +# preds=torch.randn(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) +# ) + +# _input_binary_single = Input(preds=torch.randn((NUM_BATCHES, 1)), target=torch.randint(high=2, size=(NUM_BATCHES, 1))) + +# _input_multiclass = Input( +# preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), +# target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), +# ) + + +# def _sk_hinge(preds, target, squared, multiclass_mode): +# sk_preds, sk_target = preds.numpy(), target.numpy() + +# if multiclass_mode == MulticlassMode.ONE_VS_ALL: +# enc = OneHotEncoder() +# enc.fit(sk_target.reshape(-1, 1)) +# sk_target = enc.transform(sk_target.reshape(-1, 1)).toarray() + +# if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: +# sk_target = 2 * sk_target - 1 + +# if squared or sk_target.max() != 1 or sk_target.min() != -1: +# # Squared not an option in sklearn and infers classes incorrectly with single element, so adapted from source +# if sk_preds.ndim == 1 or multiclass_mode == MulticlassMode.ONE_VS_ALL: +# margin = sk_target * sk_preds +# else: +# mask = np.ones_like(sk_preds, dtype=bool) +# mask[np.arange(sk_target.shape[0]), sk_target] = False +# margin = sk_preds[~mask] +# margin -= np.max(sk_preds[mask].reshape(sk_target.shape[0], -1), axis=1) +# measures = 1 - margin +# measures = np.clip(measures, 0, None) + +# if squared: +# measures = measures**2 +# return measures.mean(axis=0) +# if multiclass_mode == MulticlassMode.ONE_VS_ALL: +# result = np.zeros(sk_preds.shape[1]) +# for i in range(result.shape[0]): +# result[i] = sk_hinge(y_true=sk_target[:, i], pred_decision=sk_preds[:, i]) +# return result + +# return sk_hinge(y_true=sk_target, pred_decision=sk_preds) + + +# @pytest.mark.parametrize( +# "preds, target, squared, multiclass_mode", +# [ +# (_input_binary.preds, _input_binary.target, False, None), +# (_input_binary.preds, _input_binary.target, True, None), +# (_input_binary_single.preds, _input_binary_single.target, False, None), +# (_input_binary_single.preds, _input_binary_single.target, True, None), +# (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.CRAMMER_SINGER), +# (_input_multiclass.preds, _input_multiclass.target, True, 
MulticlassMode.CRAMMER_SINGER), +# (_input_multiclass.preds, _input_multiclass.target, False, MulticlassMode.ONE_VS_ALL), +# (_input_multiclass.preds, _input_multiclass.target, True, MulticlassMode.ONE_VS_ALL), +# ], +# ) +# class TestHinge(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_hinge_class(self, ddp, dist_sync_on_step, preds, target, squared, multiclass_mode): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=HingeLoss, +# sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "squared": squared, +# "multiclass_mode": multiclass_mode, +# }, +# ) + +# def test_hinge_fn(self, preds, target, squared, multiclass_mode): +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), +# sk_metric=partial(_sk_hinge, squared=squared, multiclass_mode=multiclass_mode), +# ) + +# def test_hinge_differentiability(self, preds, target, squared, multiclass_mode): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=HingeLoss, +# metric_functional=partial(hinge_loss, squared=squared, multiclass_mode=multiclass_mode), +# ) + + +# _input_multi_target = Input(preds=torch.randn(BATCH_SIZE), target=torch.randint(high=2, size=(BATCH_SIZE, 2))) + +# _input_binary_different_sizes = Input( +# preds=torch.randn(BATCH_SIZE * 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) +# ) + +# _input_multi_different_sizes = Input( +# preds=torch.randn(BATCH_SIZE * 2, NUM_CLASSES), target=torch.randint(high=NUM_CLASSES, size=(BATCH_SIZE,)) +# ) + +# _input_extra_dim = Input( +# preds=torch.randn(BATCH_SIZE, NUM_CLASSES, 2), target=torch.randint(high=2, size=(BATCH_SIZE,)) +# ) + + +# @pytest.mark.parametrize( +# "preds, target, multiclass_mode", +# [ +# (_input_multi_target.preds, _input_multi_target.target, None), +# (_input_binary_different_sizes.preds, _input_binary_different_sizes.target, None), +# (_input_multi_different_sizes.preds, _input_multi_different_sizes.target, None), +# (_input_extra_dim.preds, _input_extra_dim.target, None), +# (_input_multiclass.preds[0], _input_multiclass.target[0], "invalid_mode"), +# ], +# ) +# def test_bad_inputs_fn(preds, target, multiclass_mode): +# with pytest.raises(ValueError): +# _ = hinge_loss(preds, target, multiclass_mode=multiclass_mode) + + +# def test_bad_inputs_class(): +# with pytest.raises(ValueError): +# HingeLoss(multiclass_mode="invalid_mode") From 7067bfb585391408a06c99401ba67f571145472f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 4 Aug 2022 11:27:11 +0200 Subject: [PATCH 08/17] init files --- src/torchmetrics/__init__.py | 8 ++++++++ src/torchmetrics/classification/__init__.py | 8 ++++++-- src/torchmetrics/functional/__init__.py | 12 ++++++++++-- .../functional/classification/__init__.py | 4 ++-- 4 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 5230a56d1bd..18e48d57de9 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -26,11 +26,13 @@ ROC, Accuracy, AveragePrecision, + BinaryCalibrationError, BinaryCohenKappa, BinaryConfusionMatrix, BinaryF1Score, BinaryFBetaScore, BinaryHammingDistance, + BinaryHingeLoss, BinaryJaccardIndex, BinaryMatthewsCorrCoef, BinaryPrecision, @@ -54,11 +56,13 @@ 
LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, + MulticlassCalibrationError, MulticlassCohenKappa, MulticlassConfusionMatrix, MulticlassF1Score, MulticlassFBetaScore, MulticlassHammingDistance, + MulticlassHingeLoss, MulticlassJaccardIndex, MulticlassMatthewsCorrCoef, MulticlassPrecision, @@ -259,4 +263,8 @@ "WordErrorRate", "WordInfoLost", "WordInfoPreserved", + "BinaryCalibrationError", + "MulticlassHingeLoss", + "BinaryHingeLoss", + "MulticlassCalibrationError", ] diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 2a08575de92..64814c06dfa 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -31,7 +31,11 @@ from torchmetrics.classification.binned_precision_recall import BinnedAveragePrecision # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 -from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 +from torchmetrics.classification.calibration_error import ( # noqa: F401 + BinaryCalibrationError, + CalibrationError, + MulticlassCalibrationError, +) from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401 from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import ( # noqa: F401 @@ -50,7 +54,7 @@ MulticlassHammingDistance, MultilabelHammingDistance, ) -from torchmetrics.classification.hinge import HingeLoss # noqa: F401 +from torchmetrics.classification.hinge import BinaryHingeLoss, HingeLoss, MulticlassHingeLoss # noqa: F401 from torchmetrics.classification.jaccard import ( # noqa: F401 BinaryJaccardIndex, JaccardIndex, diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 321311cb60a..e2362b67db7 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -18,7 +18,11 @@ from torchmetrics.functional.classification.auc import auc from torchmetrics.functional.classification.auroc import auroc from torchmetrics.functional.classification.average_precision import average_precision -from torchmetrics.functional.classification.calibration_error import calibration_error +from torchmetrics.functional.classification.calibration_error import ( + binary_calibration_error, + calibration_error, + multiclass_calibration_error, +) from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, @@ -43,7 +47,7 @@ multiclass_hamming_distance, multilabel_hamming_distance, ) -from torchmetrics.functional.classification.hinge import hinge_loss +from torchmetrics.functional.classification.hinge import binary_hinge_loss, hinge_loss, multiclass_hinge_loss from torchmetrics.functional.classification.jaccard import ( binary_jaccard_index, jaccard_index, @@ -253,4 +257,8 @@ "binary_recall", "multiclass_recall", "multilabel_recall", + "binary_calibration_error", + "multiclass_calibration_error", + "binary_hinge_loss", + "multiclass_hinge_loss", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 306c11d9fe3..f405886b366 100644 --- 
a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -16,8 +16,8 @@ from torchmetrics.functional.classification.auroc import auroc # noqa: F401 from torchmetrics.functional.classification.average_precision import average_precision # noqa: F401 from torchmetrics.functional.classification.calibration_error import ( # noqa: F401 - calibration_error, binary_calibration_error, + calibration_error, multiclass_calibration_error, ) from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 @@ -45,8 +45,8 @@ multilabel_hamming_distance, ) from torchmetrics.functional.classification.hinge import ( # noqa: F401 - hinge_loss, binary_hinge_loss, + hinge_loss, multiclass_hinge_loss, ) from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 From 9833f2ac1b437b833038bac5a77838e90e542483 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 4 Aug 2022 11:28:46 +0200 Subject: [PATCH 09/17] docs calibration error --- .../classification/calibration_error.py | 108 ++++++++++++++++++ .../classification/calibration_error.py | 104 +++++++++++++++++ 2 files changed, 212 insertions(+) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index a9526fd5806..a6d312ffe04 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -33,6 +33,58 @@ class BinaryCalibrationError(Metric): + r"""`Computes the Top-label Calibration Error`_ for binary tasks. The expected calibration error can be used to + quantify how well a given model is calibrated e.g. how well the predicted output probabilities of the model matches + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in an uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Example: + >>> from torchmetrics import BinaryCalibrationError + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> metric = BinaryCalibrationError(n_bins=2, norm='l1') + >>> metric(preds, target) + tensor(0.2900) + >>> metric = BinaryCalibrationError(n_bins=2, norm='l2') + >>> metric(preds, target) + tensor(0.2918) + >>> metric = BinaryCalibrationError(n_bins=2, norm='max') + >>> metric(preds, target) + tensor(0.3167) + """ is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False @@ -72,6 +124,62 @@ def compute(self) -> Tensor: class MulticlassCalibrationError(Metric): + r"""`Computes the Top-label Calibration Error`_ for multiclass tasks. The expected calibration error can be used to + quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in a uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifying the number of classes + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics import MulticlassCalibrationError + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... 
[0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l1') + >>> metric(preds, target) + tensor(0.2000) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='l2') + >>> metric(preds, target) + tensor(0.2082) + >>> metric = MulticlassCalibrationError(num_classes=3, n_bins=3, norm='max') + >>> metric(preds, target) + tensor(0.2333) + """ is_differentiable: bool = False higher_is_better: bool = False full_state_update: bool = False diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 2faf01a8f24..0d21825bf7f 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -176,6 +176,56 @@ def binary_calibration_error( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + r"""`Computes the Top-label Calibration Error`_ for binary tasks. The expected calibration error can be used to + quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in a uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Example: + >>> from torchmetrics.functional import binary_calibration_error + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> binary_calibration_error(preds, target, n_bins=2, norm='l1') + tensor(0.2900) + >>> binary_calibration_error(preds, target, n_bins=2, norm='l2') + tensor(0.2918) + >>> binary_calibration_error(preds, target, n_bins=2, norm='max') + tensor(0.3167) + """ if validate_args: _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) _binary_calibration_error_tensor_validation(preds, target, ignore_index) @@ -234,6 +284,60 @@ def multiclass_calibration_error( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + r"""`Computes the Top-label Calibration Error`_ for multiclass tasks. The expected calibration error can be used to + quantify how well a given model is calibrated, e.g. how well the predicted output probabilities of the model match + the actual probabilities of the ground truth distribution. + + Three different norms are implemented, each corresponding to variations on the calibration error metric. + + .. math:: + \text{ECE} = \sum_i^N b_i \|(p_i - c_i)\|, \text{L1 norm (Expected Calibration Error)} + + .. math:: + \text{MCE} = \max_{i} (p_i - c_i), \text{Infinity norm (Maximum Calibration Error)} + + .. math:: + \text{RMSCE} = \sqrt{\sum_i^N b_i(p_i - c_i)^2}, \text{L2 norm (Root Mean Square Calibration Error)} + + Where :math:`p_i` is the top-1 prediction accuracy in bin :math:`i`, :math:`c_i` is the average confidence of + predictions in bin :math:`i`, and :math:`b_i` is the fraction of data points in bin :math:`i`. Bins are constructed + in a uniform way in the [0,1] range. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifying the number of classes + n_bins: Number of bins to use when computing the metric. + norm: Norm used to compare empirical and expected probability bins. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional import multiclass_calibration_error + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... 
[0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l1') + tensor(0.2000) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='l2') + tensor(0.2082) + >>> multiclass_calibration_error(preds, target, num_classes=3, n_bins=3, norm='max') + tensor(0.2333) + """ if validate_args: _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) From ec99c8771232f442ba20ae0e0aee0d250d96bfe7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 4 Aug 2022 15:07:00 +0200 Subject: [PATCH 10/17] update --- src/torchmetrics/classification/hinge.py | 54 ++++- .../functional/classification/hinge.py | 41 ++-- tests/unittests/classification/test_hinge.py | 186 +++++++++--------- 3 files changed, 162 insertions(+), 119 deletions(-) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index 9fb128a6f73..f790340c1e2 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -15,10 +15,62 @@ from torch import Tensor, tensor -from torchmetrics.functional.classification.hinge import MulticlassMode, _hinge_compute, _hinge_update +from torchmetrics.functional.classification.hinge import ( + MulticlassMode, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_hinge_loss_arg_validation, + _binary_hinge_loss_update, + _hinge_compute, + _hinge_loss_compute, + _hinge_update, +) from torchmetrics.metric import Metric +class BinaryHingeLoss(Metric): + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + squared: bool = False, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_hinge_loss_arg_validation(squared, threshold, ignore_index) + self.validate_args = validate_args + self.squared = squared + self.threshold = threshold + self.ignore_index = ignore_index + + self.add_state("measures", default=tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + measures, total = _binary_hinge_loss_update(preds, target, self.squared, self.ignore_index) + self.measures += measures + self.total += total + + def compute(self) -> Tensor: + return _hinge_loss_compute(self.measures, self.total) + + +class MulticlassHingeLoss(Metric): + pass + + +# -------------------------- Old stuff -------------------------- + + class HingeLoss(Metric): r"""Computes the mean `Hinge loss`_, typically used for Support Vector Machines (SVMs). 
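The calibration-error docstrings added in the patches above all build on the same binned formulation. As a point of reference, a minimal sketch of the binary, top-label ``norm='l1'`` (ECE) case could look roughly like the snippet below; this is illustrative only and not part of any patch: the helper name is made up, bin-edge handling may differ from the real implementation, and it ignores ``ignore_index``, logit inputs and the other norms.

import torch


def binary_expected_calibration_error(preds: torch.Tensor, target: torch.Tensor, n_bins: int = 10) -> torch.Tensor:
    # Illustrative ECE: `preds` are probabilities of the positive class, `target` holds {0,1} labels.
    confidences = torch.where(preds >= 0.5, preds, 1 - preds)  # c_i inputs: confidence of the predicted class
    accuracies = ((preds >= 0.5).long() == target).float()     # p_i inputs: 1 where the predicted class is correct
    bin_edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.zeros(())
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        b = in_bin.float().mean()  # b_i: fraction of samples falling in this bin
        if b > 0:
            ece += b * (confidences[in_bin].mean() - accuracies[in_bin].mean()).abs()
    return ece


preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
target = torch.tensor([0, 0, 1, 1, 1])
print(binary_expected_calibration_error(preds, target, n_bins=2))  # tensor(0.2900)

With ``n_bins=2`` this reproduces the ``norm='l1'`` value from the ``BinaryCalibrationError`` doctest above: all five predictions land in the upper bin with mean confidence 0.71 against an accuracy of 1.0, giving |0.71 - 1.0| = 0.29.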
diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index a5f86708e28..0124e8e77e0 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -17,6 +17,10 @@ from torch import Tensor, tensor from typing_extensions import Literal +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, +) from torchmetrics.utilities.checks import _check_same_shape, _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr @@ -26,43 +30,24 @@ def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: return measure / total -def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None: +def _binary_hinge_loss_arg_validation( + squared: bool, threshold: float = 0.5, ignore_index: Optional[int] = None +) -> None: if not isinstance(squared, bool): raise ValueError(f"Expected argument `squared` to be an bool but got {squared}") + if not (isinstance(threshold, float) and (0 <= threshold <= 1)): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") -def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None: - # Check that they have same shape - _check_same_shape(preds, target) - - # Check that target only contains {0,1} values or value in ignore_index - unique_values = torch.unique(target) - if ignore_index is None: - check = torch.any((unique_values != 0) & (unique_values != 1)) - else: - check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) - if check: - raise RuntimeError( - f"Detected the following values in `target`: {unique_values} but expected only" - f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." 
- ) - - if not preds.is_floating_point(): - raise ValueError( - "Expected argument `preds` to be an floating tensor with probability/logit scores," - f" but got tensor with dtype {preds.dtype}" - ) - - def _binary_hinge_loss_update( preds: Tensor, target: Tensor, squared: bool, ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - + preds = preds.float() target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] @@ -82,12 +67,14 @@ def binary_hinge_loss( preds: Tensor, target: Tensor, squared: bool = False, + threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = False, ) -> Tensor: if validate_args: - _binary_hinge_loss_arg_validation(squared, ignore_index) - _binary_hinge_loss_tensor_validation(preds, target, ignore_index) + _binary_hinge_loss_arg_validation(squared, threshold, ignore_index) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) measures, total = _binary_hinge_loss_update(preds, target, squared, ignore_index) return _hinge_loss_compute(measures, total) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 3c1dcd4b473..2131042332a 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -16,105 +16,109 @@ import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder -from torchmetrics import HingeLoss +from torchmetrics.classification.hinge import BinaryHingeLoss from torchmetrics.functional import hinge_loss -from torchmetrics.functional.classification.hinge import MulticlassMode -from unittests.classification.inputs import Input -from unittests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, MetricTester +from torchmetrics.functional.classification.hinge import binary_hinge_loss +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases +from unittests.helpers.testers import ( + BATCH_SIZE, + NUM_BATCHES, + NUM_CLASSES, + THRESHOLD, + MetricTester, + inject_ignore_index, + remove_ignore_index, +) torch.manual_seed(42) -# def _sk_binary_hinge_loss(preds, target, n_bins, norm, ignore_index): -# preds = preds.numpy().flatten() -# target = target.numpy().flatten() -# if not ((0 < preds) & (preds < 1)).all(): -# preds = sigmoid(preds) -# target, preds = remove_ignore_index(target, preds, ignore_index) -# metric = ECE if norm == "l1" else MCE -# return metric(n_bins).measure(preds, target) - - -# @pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) -# class TestBinaryHingeLoss(MetricTester): -# @pytest.mark.parametrize("n_bins", [10, 15, 20]) -# @pytest.mark.parametrize("norm", ["l1", "max"]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# @pytest.mark.parametrize("ddp", [True, False]) -# def test_binary_hinge_loss(self, input, ddp, n_bins, norm, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = inject_ignore_index(target, ignore_index) -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=BinaryHingeLoss, -# sk_metric=partial(_sk_binary_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), -# metric_args={ -# "n_bins": n_bins, -# 
"norm": norm, -# "ignore_index": ignore_index, -# }, -# ) - -# @pytest.mark.parametrize("n_bins", [10, 15, 20]) -# @pytest.mark.parametrize("norm", ["l1", "max"]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# def test_binary_hinge_loss_functional(self, input, n_bins, norm, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = inject_ignore_index(target, ignore_index) -# self.run_functional_metric_test( -# preds=preds, -# target=target, -# metric_functional=binary_hinge_loss, -# sk_metric=partial(_sk_binary_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), -# metric_args={ -# "n_bins": n_bins, -# "norm": norm, -# "ignore_index": ignore_index, -# }, -# ) - -# def test_binary_hinge_loss_differentiability(self, input): -# preds, target = input -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=BinaryHingeLoss, -# metric_functional=binary_hinge_loss, -# ) - -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_binary_hinge_loss_dtype_cpu(self, input, dtype): -# preds, target = input -# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: -# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") -# if (preds < 0).any() and dtype == torch.half: -# pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") -# self.run_precision_test_cpu( -# preds=preds, -# target=target, -# metric_module=BinaryHingeLoss, -# metric_functional=binary_hinge_loss, -# dtype=dtype, -# ) - -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_binary_hinge_loss_dtype_gpu(self, input, dtype): -# preds, target = input -# self.run_precision_test_gpu( -# preds=preds, -# target=target, -# metric_module=BinaryHingeLoss, -# metric_functional=binary_hinge_loss, -# dtype=dtype, -# ) +def _sk_binary_hinge_loss(preds, target, ignore_index): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + + target, preds = remove_ignore_index(target, preds, ignore_index) + target = 2 * target - 1 + return sk_hinge(target, preds, labels=[0, 1]) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryHingeLoss(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_hinge_loss(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryHingeLoss, + sk_metric=partial(_sk_binary_hinge_loss, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_hinge_loss_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_hinge_loss, + sk_metric=partial(_sk_binary_hinge_loss, ignore_index=ignore_index), + metric_args={ + "ignore_index": ignore_index, + }, + ) + + def test_binary_hinge_loss_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + 
metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hinge_loss_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hinge_loss_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryHingeLoss, + metric_functional=binary_hinge_loss, + dtype=dtype, + ) # def _sk_multiclass_hinge_loss(preds, target, n_bins, norm, ignore_index): From eccdf09453eacab9687a879fac9a47804b7044c4 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 5 Aug 2022 13:47:40 +0200 Subject: [PATCH 11/17] hinge loss --- .../classification/calibration_error.py | 4 +- src/torchmetrics/classification/hinge.py | 162 +++++++++++-- .../classification/calibration_error.py | 4 +- .../classification/confusion_matrix.py | 10 +- .../functional/classification/hinge.py | 163 ++++++++++++-- tests/unittests/classification/test_hinge.py | 212 +++++++++--------- tests/unittests/helpers/testers.py | 4 +- 7 files changed, 407 insertions(+), 152 deletions(-) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index a6d312ffe04..60e8ce961fb 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -111,7 +111,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: _binary_calibration_error_tensor_validation(preds, target, self.ignore_index) preds, target = _binary_confusion_matrix_format( - preds, target, threshold=0.0, ignore_index=self.ignore_index, should_threshold=False + preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False ) confidences, accuracies = _binary_calibration_error_update(preds, target) self.confidences.append(confidences) @@ -208,7 +208,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: _multiclass_calibration_error_tensor_validation(preds, target, self.num_classes, self.ignore_index) preds, target = _multiclass_confusion_matrix_format( - preds, target, ignore_index=self.ignore_index, should_threshold=False + preds, target, ignore_index=self.ignore_index, convert_to_labels=False ) confidences, accuracies = _multiclass_calibration_error_update(preds, target) self.confidences.append(confidences) diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index f790340c1e2..2faf91860ed 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -13,29 +13,72 @@ # limitations under the License. 
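The tests above pin down the same elementwise rule that the new ``BinaryHingeLoss`` module implements: max(0, 1 - y * y_hat) with y in {-1, +1}. The short standalone sketch below is not part of the patch (the toy tensors are illustrative assumptions); it only shows that the masked-margin trick in ``_binary_hinge_loss_update`` is equivalent to remapping {0, 1} targets to {-1, +1} before clamping, and that it reproduces the values quoted in the new docstrings.

    import torch
    from torch import Tensor

    def binary_hinge_sketch(preds: Tensor, target: Tensor, squared: bool = False) -> Tensor:
        # margin[target] = preds and margin[~target] = -preds is the same as y * preds with y in {-1, +1}
        y = 2 * target.float() - 1
        measures = torch.clamp(1 - y * preds, min=0)
        if squared:
            measures = measures.pow(2)
        # same reduction as _hinge_loss_compute: summed measures divided by the number of samples
        return measures.sum() / target.shape[0]

    preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75])
    target = torch.tensor([0, 0, 1, 1, 1])
    binary_hinge_sketch(preds, target)                 # tensor(0.6900)
    binary_hinge_sketch(preds, target, squared=True)   # tensor(0.6905)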
from typing import Any, Optional, Union -from torch import Tensor, tensor +import torch +from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.hinge import ( MulticlassMode, _binary_confusion_matrix_format, - _binary_confusion_matrix_tensor_validation, _binary_hinge_loss_arg_validation, + _binary_hinge_loss_tensor_validation, _binary_hinge_loss_update, _hinge_compute, _hinge_loss_compute, _hinge_update, + _multiclass_confusion_matrix_format, + _multiclass_hinge_loss_arg_validation, + _multiclass_hinge_loss_tensor_validation, + _multiclass_hinge_loss_update, ) from torchmetrics.metric import Metric class BinaryHingeLoss(Metric): + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks. It is + defined as: + + .. math:: + \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) + + Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ + Example: + >>> from torchmetrics import BinaryHingeLoss + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> metric = BinaryHingeLoss() + >>> metric(preds, target) + tensor(0.6900) + >>> metric = BinaryHingeLoss(squared=True) + >>> metric(preds, target) + tensor(0.6905) + """ is_differentiable: bool = True higher_is_better: bool = False full_state_update: bool = False def __init__( self, - threshold: float = 0.5, squared: bool = False, ignore_index: Optional[int] = None, validate_args: bool = True, @@ -43,20 +86,21 @@ def __init__( ) -> None: super().__init__(**kwargs) if validate_args: - _binary_hinge_loss_arg_validation(squared, threshold, ignore_index) + _binary_hinge_loss_arg_validation(squared, ignore_index) self.validate_args = validate_args self.squared = squared - self.threshold = threshold self.ignore_index = ignore_index - self.add_state("measures", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.add_state("measures", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: - _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) - measures, total = _binary_hinge_loss_update(preds, target, self.squared, self.ignore_index) + _binary_hinge_loss_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=self.ignore_index, convert_to_labels=False + ) + measures, total = _binary_hinge_loss_update(preds, target, self.squared) self.measures += measures self.total += total @@ -65,7 +109,99 @@ def compute(self) -> Tensor: class MulticlassHingeLoss(Metric): - pass + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks + + The metric can be computed in two ways. Either, the definition by Crammer and Singer is used: + + .. math:: + \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) + + Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), + and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can + also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. 
+ multiclass_mode: + Determines how to compute the metric + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics import MulticlassHingeLoss + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... [0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> metric = MulticlassHingeLoss(num_classes=3) + >>> metric(preds, target) + tensor(0.9125) + >>> metric = MulticlassHingeLoss(num_classes=3, squared=True) + >>> metric(preds, target) + tensor(1.1131) + >>> metric = MulticlassHingeLoss(num_classes=3, multiclass_mode='one-vs-all') + >>> metric(preds, target) + tensor([0.8750, 1.1250, 1.1000]) + """ + is_differentiable: bool = True + higher_is_better: bool = False + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + squared: bool = False, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) + self.validate_args = validate_args + self.num_classes = num_classes + self.squared = squared + self.multiclass_mode = multiclass_mode + self.ignore_index = ignore_index + + self.add_state( + "measures", + default=torch.tensor(0.0) + if self.multiclass_mode == "crammer-singer" + else torch.zeros( + num_classes, + ), + dist_reduce_fx="sum", + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_hinge_loss_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index, convert_to_labels=False) + measures, total = _multiclass_hinge_loss_update(preds, target, self.squared, self.multiclass_mode) + self.measures += measures + self.total += total + + def compute(self) -> Tensor: + return _hinge_loss_compute(self.measures, self.total) # -------------------------- Old stuff -------------------------- @@ -153,8 +289,8 @@ def __init__( ) -> None: super().__init__(**kwargs) - self.add_state("measure", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.add_state("measure", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") if multiclass_mode not in (None, MulticlassMode.CRAMMER_SINGER, MulticlassMode.ONE_VS_ALL): raise ValueError( diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 0d21825bf7f..2ce8e9dd29f 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -230,7 +230,7 @@ def binary_calibration_error( _binary_calibration_error_arg_validation(n_bins, norm, ignore_index) _binary_calibration_error_tensor_validation(preds, target, ignore_index) preds, target = _binary_confusion_matrix_format( - preds, target, 
threshold=0.0, ignore_index=ignore_index, should_threshold=False + preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False ) confidences, accuracies = _binary_calibration_error_update(preds, target) return _ce_compute(confidences, accuracies, n_bins, norm) @@ -341,7 +341,7 @@ def multiclass_calibration_error( if validate_args: _multiclass_calibration_error_arg_validation(num_classes, n_bins, norm, ignore_index) _multiclass_calibration_error_tensor_validation(preds, target, num_classes, ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, should_threshold=False) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False) confidences, accuracies = _multiclass_calibration_error_update(preds, target) return _ce_compute(confidences, accuracies, n_bins, norm) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index b3c144c5f8d..a2a79600789 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -116,7 +116,7 @@ def _binary_confusion_matrix_format( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, - should_threshold: bool = True, + convert_to_labels: bool = True, ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. @@ -135,7 +135,7 @@ def _binary_confusion_matrix_format( if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid preds = preds.sigmoid() - if should_threshold: + if convert_to_labels: preds = preds > threshold return preds, target @@ -301,7 +301,7 @@ def _multiclass_confusion_matrix_format( preds: Tensor, target: Tensor, ignore_index: Optional[int] = None, - should_threshold: bool = True, + convert_to_labels: bool = True, ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. 
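The ``convert_to_labels`` rename above is what lets hinge loss and calibration error reuse the confusion-matrix input formatting while keeping real-valued scores instead of thresholded labels. A rough sketch of the binary path (flattening and ``ignore_index`` handling are left out, and the helper below is an illustrative stand-in, not the patched function):

    import torch
    from torch import Tensor

    def binary_format_sketch(preds: Tensor, target: Tensor, threshold: float = 0.5, convert_to_labels: bool = True):
        if not torch.all((0 <= preds) & (preds <= 1)):
            preds = preds.sigmoid()      # values outside [0, 1] are treated as logits
        if convert_to_labels:
            preds = preds > threshold    # hard labels for confusion-matrix style metrics
        return preds, target             # with convert_to_labels=False the scores pass through

    logits = torch.tensor([-1.0, 0.5, 2.0])
    target = torch.tensor([0, 1, 1])
    binary_format_sketch(logits, target)                           # thresholded boolean labels
    binary_format_sketch(logits, target, convert_to_labels=False)  # sigmoid probabilities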
@@ -309,10 +309,10 @@ def _multiclass_confusion_matrix_format( - Remove all datapoints that should be ignored """ # Apply argmax if we have one more dimension - if preds.ndim == target.ndim + 1 and should_threshold: + if preds.ndim == target.ndim + 1 and convert_to_labels: preds = preds.argmax(dim=1) - if should_threshold: + if convert_to_labels: preds = preds.flatten() else: preds = _movedim(preds, 1, -1).reshape(-1, preds.shape[1]) diff --git a/src/torchmetrics/functional/classification/hinge.py b/src/torchmetrics/functional/classification/hinge.py index 0124e8e77e0..623900b23cc 100644 --- a/src/torchmetrics/functional/classification/hinge.py +++ b/src/torchmetrics/functional/classification/hinge.py @@ -20,8 +20,10 @@ from torchmetrics.functional.classification.confusion_matrix import ( _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, ) -from torchmetrics.utilities.checks import _check_same_shape, _input_squeeze +from torchmetrics.utilities.checks import _input_squeeze from torchmetrics.utilities.data import to_onehot from torchmetrics.utilities.enums import DataType, EnumStr @@ -30,24 +32,28 @@ def _hinge_loss_compute(measure: Tensor, total: Tensor) -> Tensor: return measure / total -def _binary_hinge_loss_arg_validation( - squared: bool, threshold: float = 0.5, ignore_index: Optional[int] = None -) -> None: +def _binary_hinge_loss_arg_validation(squared: bool, ignore_index: Optional[int] = None) -> None: if not isinstance(squared, bool): raise ValueError(f"Expected argument `squared` to be an bool but got {squared}") - if not (isinstance(threshold, float) and (0 <= threshold <= 1)): - raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") +def _binary_hinge_loss_tensor_validation(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> None: + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + def _binary_hinge_loss_update( preds: Tensor, target: Tensor, squared: bool, - ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - preds = preds.float() + target = target.bool() margin = torch.zeros_like(preds) margin[target] = preds[target] @@ -67,15 +73,53 @@ def binary_hinge_loss( preds: Tensor, target: Tensor, squared: bool = False, - threshold: float = 0.5, ignore_index: Optional[int] = None, validate_args: bool = False, ) -> Tensor: + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for binary tasks. It is + defined as: + + .. math:: + \text{Hinge loss} = \max(0, 1 - y \times \hat{y}) + + Where :math:`y \in {-1, 1}` is the target, and :math:`\hat{y} \in \mathbb{R}` is the prediction. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. 
Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional import binary_hinge_loss + >>> preds = torch.tensor([0.25, 0.25, 0.55, 0.75, 0.75]) + >>> target = torch.tensor([0, 0, 1, 1, 1]) + >>> binary_hinge_loss(preds, target) + tensor(0.6900) + >>> binary_hinge_loss(preds, target, squared=True) + tensor(0.6905) + """ if validate_args: - _binary_hinge_loss_arg_validation(squared, threshold, ignore_index) - _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) - measures, total = _binary_hinge_loss_update(preds, target, squared, ignore_index) + _binary_hinge_loss_arg_validation(squared, ignore_index) + _binary_hinge_loss_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format( + preds, target, threshold=0.0, ignore_index=ignore_index, convert_to_labels=False + ) + measures, total = _binary_hinge_loss_update(preds, target, squared) return _hinge_loss_compute(measures, total) @@ -88,17 +132,49 @@ def _multiclass_hinge_loss_arg_validation( _binary_hinge_loss_arg_validation(squared, ignore_index) if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") - allowed_mm = ("true", "pred", "all", "none", None) + allowed_mm = ("crammer-singer", "one-vs-all") if multiclass_mode not in allowed_mm: raise ValueError(f"Expected argument `multiclass_mode` to be one of {allowed_mm}, but got {multiclass_mode}.") -def _multiclass_hinge_loss_tensor_validation(): - pass +def _multiclass_hinge_loss_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + if not preds.is_floating_point(): + raise ValueError( + "Expected argument `preds` to be floating tensor with probabilities/logits" + f" but got tensor with dtype {preds.dtype}" + ) + + +def _multiclass_hinge_loss_update( + preds: Tensor, + target: Tensor, + squared: bool, + multiclass_mode: Literal["crammer-singer", "one-vs-all"] = "crammer-singer", +) -> Tuple[Tensor, Tensor]: + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.softmax(1) + + target = to_onehot(target, max(2, preds.shape[1])).bool() + if multiclass_mode == "crammer-singer": + margin = preds[target] + margin -= torch.max(preds[~target].view(preds.shape[0], -1), dim=1)[0] + else: + target = target.bool() + margin = torch.zeros_like(preds) + margin[target] = preds[target] + margin[~target] = -preds[~target] + + measures = 1 - margin + measures = torch.clamp(measures, 0) + if squared: + measures = measures.pow(2) -def _multiclass_hinge_loss_update(): - pass + total = tensor(target.shape[0], device=target.device) + return measures.sum(dim=0), total def 
multiclass_hinge_loss( @@ -110,10 +186,59 @@ def multiclass_hinge_loss( ignore_index: Optional[int] = None, validate_args: bool = False, ) -> Tensor: + r"""Computes the mean `Hinge loss`_ typically used for Support Vector Machines (SVMs) for multiclass tasks + + The metric can be computed in two ways. Either, the definition by Crammer and Singer is used: + + .. math:: + \text{Hinge loss} = \max\left(0, 1 - \hat{y}_y + \max_{i \ne y} (\hat{y}_i)\right) + + Where :math:`y \in {0, ..., \mathrm{C}}` is the target class (where :math:`\mathrm{C}` is the number of classes), + and :math:`\hat{y} \in \mathbb{R}^\mathrm{C}` is the predicted output per class. Alternatively, the metric can + also be computed in one-vs-all approach, where each class is valued against all other classes in a binary fashion. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + squared: + If True, this will compute the squared hinge loss. Otherwise, computes the regular hinge loss. + multiclass_mode: + Determines how to compute the metric + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional import multiclass_hinge_loss + >>> preds = torch.tensor([[0.25, 0.20, 0.55], + ... [0.55, 0.05, 0.40], + ... [0.10, 0.30, 0.60], + ... 
[0.90, 0.05, 0.05]]) + >>> target = torch.tensor([0, 1, 2, 0]) + >>> multiclass_hinge_loss(preds, target, num_classes=3) + tensor(0.9125) + >>> multiclass_hinge_loss(preds, target, num_classes=3, squared=True) + tensor(1.1131) + >>> multiclass_hinge_loss(preds, target, num_classes=3, multiclass_mode='one-vs-all') + tensor([0.8750, 1.1250, 1.1000]) + """ if validate_args: _multiclass_hinge_loss_arg_validation(num_classes, squared, multiclass_mode, ignore_index) _multiclass_hinge_loss_tensor_validation(preds, target, num_classes, ignore_index) - measures, total = _multiclass_hinge_loss_update() + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index, convert_to_labels=False) + measures, total = _multiclass_hinge_loss_update(preds, target, squared, multiclass_mode) return _hinge_loss_compute(measures, total) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index 2131042332a..ec4f883f712 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -17,23 +17,14 @@ import pytest import torch from scipy.special import expit as sigmoid +from scipy.special import softmax from sklearn.metrics import hinge_loss as sk_hinge from sklearn.preprocessing import OneHotEncoder -from torchmetrics.classification.hinge import BinaryHingeLoss -from torchmetrics.functional import hinge_loss -from torchmetrics.functional.classification.hinge import binary_hinge_loss -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from unittests.classification.inputs import _binary_cases -from unittests.helpers.testers import ( - BATCH_SIZE, - NUM_BATCHES, - NUM_CLASSES, - THRESHOLD, - MetricTester, - inject_ignore_index, - remove_ignore_index, -) +from torchmetrics.classification.hinge import BinaryHingeLoss, MulticlassHingeLoss +from torchmetrics.functional.classification.hinge import binary_hinge_loss, multiclass_hinge_loss +from unittests.classification.inputs import _binary_cases, _multiclass_cases +from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index torch.manual_seed(42) @@ -46,10 +37,10 @@ def _sk_binary_hinge_loss(preds, target, ignore_index): target, preds = remove_ignore_index(target, preds, ignore_index) target = 2 * target - 1 - return sk_hinge(target, preds, labels=[0, 1]) + return sk_hinge(target, preds) -@pytest.mark.parametrize("input", _binary_cases) +@pytest.mark.parametrize("input", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) class TestBinaryHingeLoss(MetricTester): @pytest.mark.parametrize("ignore_index", [None, -1]) @pytest.mark.parametrize("ddp", [True, False]) @@ -64,7 +55,6 @@ def test_binary_hinge_loss(self, input, ddp, ignore_index): metric_class=BinaryHingeLoss, sk_metric=partial(_sk_binary_hinge_loss, ignore_index=ignore_index), metric_args={ - "threshold": THRESHOLD, "ignore_index": ignore_index, }, ) @@ -96,10 +86,8 @@ def test_binary_hinge_loss_differentiability(self, input): @pytest.mark.parametrize("dtype", [torch.half, torch.double]) def test_binary_hinge_loss_dtype_cpu(self, input, dtype): preds, target = input - if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: - pytest.xfail(reason="half support of core ops not support before pytorch v1.6") - if (preds < 0).any() and dtype == torch.half: - pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + if dtype == torch.half: + pytest.xfail(reason="torch.clamp does not support 
cpu + half") self.run_precision_test_cpu( preds=preds, target=target, @@ -121,99 +109,105 @@ def test_binary_hinge_loss_dtype_gpu(self, input, dtype): ) -# def _sk_multiclass_hinge_loss(preds, target, n_bins, norm, ignore_index): -# preds = preds.numpy() -# target = target.numpy().flatten() -# if not ((0 < preds) & (preds < 1)).all(): -# preds = softmax(preds, 1) -# preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) -# target, preds = remove_ignore_index(target, preds, ignore_index) -# metric = ECE if norm == "l1" else MCE -# return metric(n_bins).measure(preds, target) - +def _sk_multiclass_hinge_loss(preds, target, multiclass_mode, ignore_index): + preds = preds.numpy() + target = target.numpy().flatten() + if not ((0 < preds) & (preds < 1)).all(): + preds = softmax(preds, 1) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target, preds = remove_ignore_index(target, preds, ignore_index) -# @pytest.mark.parametrize( -# "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) -# ) -# class TestMulticlassHingeLoss(MetricTester): -# @pytest.mark.parametrize("n_bins", [10, 15, 20]) -# @pytest.mark.parametrize("norm", ["l1", "max"]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# @pytest.mark.parametrize("ddp", [True, False]) -# def test_multiclass_hinge_loss(self, input, ddp, n_bins, norm, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = inject_ignore_index(target, ignore_index) -# self.run_class_metric_test( -# ddp=ddp, -# preds=preds, -# target=target, -# metric_class=MulticlassHingeLoss, -# sk_metric=partial(_sk_multiclass_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), -# metric_args={ -# "num_classes": NUM_CLASSES, -# "n_bins": n_bins, -# "norm": norm, -# "ignore_index": ignore_index, -# }, -# ) + if multiclass_mode == "one-vs-all": + enc = OneHotEncoder() + enc.fit(target.reshape(-1, 1)) + target = enc.transform(target.reshape(-1, 1)).toarray() + target = 2 * target - 1 + result = np.zeros(preds.shape[1]) + for i in range(result.shape[0]): + result[i] = sk_hinge(y_true=target[:, i], pred_decision=preds[:, i]) + return result + else: + return sk_hinge(target, preds) + + +@pytest.mark.parametrize( + "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassHingeLoss(MetricTester): + @pytest.mark.parametrize("multiclass_mode", ["crammer-singer", "one-vs-all"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_hinge_loss(self, input, ddp, multiclass_mode, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassHingeLoss, + sk_metric=partial(_sk_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "multiclass_mode": multiclass_mode, + "ignore_index": ignore_index, + }, + ) -# @pytest.mark.parametrize("n_bins", [10, 15, 20]) -# @pytest.mark.parametrize("norm", ["l1", "max"]) -# @pytest.mark.parametrize("ignore_index", [None, -1, 0]) -# def test_multiclass_hinge_loss_functional(self, input, n_bins, norm, ignore_index): -# preds, target = input -# if ignore_index is not None: -# target = inject_ignore_index(target, ignore_index) -# self.run_functional_metric_test( -# preds=preds, 
-# target=target, -# metric_functional=multiclass_hinge_loss, -# sk_metric=partial(_sk_multiclass_hinge_loss, n_bins=n_bins, norm=norm, ignore_index=ignore_index), -# metric_args={ -# "num_classes": NUM_CLASSES, -# "n_bins": n_bins, -# "norm": norm, -# "ignore_index": ignore_index, -# }, -# ) + @pytest.mark.parametrize("multiclass_mode", ["crammer-singer", "one-vs-all"]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_multiclass_hinge_loss_functional(self, input, multiclass_mode, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_hinge_loss, + sk_metric=partial(_sk_multiclass_hinge_loss, multiclass_mode=multiclass_mode, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "multiclass_mode": multiclass_mode, + "ignore_index": ignore_index, + }, + ) -# def test_multiclass_hinge_loss_differentiability(self, input): -# preds, target = input -# self.run_differentiability_test( -# preds=preds, -# target=target, -# metric_module=MulticlassHingeLoss, -# metric_functional=multiclass_hinge_loss, -# metric_args={"num_classes": NUM_CLASSES}, -# ) + def test_multiclass_hinge_loss_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + ) -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_multiclass_hinge_loss_dtype_cpu(self, input, dtype): -# preds, target = input -# if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: -# pytest.xfail(reason="half support of core ops not support before pytorch v1.6") -# self.run_precision_test_cpu( -# preds=preds, -# target=target, -# metric_module=MulticlassHingeLoss, -# metric_functional=multiclass_hinge_loss, -# metric_args={"num_classes": NUM_CLASSES}, -# dtype=dtype, -# ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hinge_loss_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half: + pytest.xfail(reason="torch.clamp does not support cpu + half") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -# @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") -# @pytest.mark.parametrize("dtype", [torch.half, torch.double]) -# def test_multiclass_hinge_loss_dtype_gpu(self, input, dtype): -# preds, target = input -# self.run_precision_test_gpu( -# preds=preds, -# target=target, -# metric_module=MulticlassHingeLoss, -# metric_functional=multiclass_hinge_loss, -# metric_args={"num_classes": NUM_CLASSES}, -# dtype=dtype, -# ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hinge_loss_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassHingeLoss, + metric_functional=multiclass_hinge_loss, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) # -------------------------- Old stuff -------------------------- diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py 
index 4f21c488203..1517064f5ab 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -633,8 +633,8 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: return x idx = torch.randperm(x.numel()) x = deepcopy(x) - # randomly set either element {3, 4, 5} to the ignore index value - skip = torch.randint(3, 6, (1,)).item() + # randomly set either element {9, 10} to the ignore index value + skip = torch.randint(9, 11, (1,)).item() x.view(-1)[idx[::skip]] = ignore_index return x From bc07fbe3e31605c3d3837d68b7ba56b4d49ddc72 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 6 Aug 2022 14:35:32 +0200 Subject: [PATCH 12/17] docs --- docs/source/classification/calibration_error.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/classification/calibration_error.rst b/docs/source/classification/calibration_error.rst index e6a9c0a9427..78da990346a 100644 --- a/docs/source/classification/calibration_error.rst +++ b/docs/source/classification/calibration_error.rst @@ -40,7 +40,7 @@ binary_calibration_error :noindex: multiclass_calibration_error -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: torchmetrics.functional.multiclass_calibration_error :noindex: From f45bebf4944254dd52f8d50716d7f50e0cd8abbf Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 15 Aug 2022 20:47:41 +0200 Subject: [PATCH 13/17] fix type issue --- src/torchmetrics/classification/calibration_error.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index 60e8ce961fb..1bbcfcf59b0 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -15,6 +15,7 @@ import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.calibration_error import ( _binary_calibration_error_arg_validation, @@ -92,7 +93,7 @@ class BinaryCalibrationError(Metric): def __init__( self, n_bins: int = 15, - norm: str = "l1", + norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, @@ -188,7 +189,7 @@ def __init__( self, num_classes: int, n_bins: int = 15, - norm: str = "l1", + norm: Literal["l1", "l2", "max"] = "l1", ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, From 6504e4374b743412c3199848a88dfa511535a82f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 17 Aug 2022 16:32:12 +0200 Subject: [PATCH 14/17] fix --- tests/unittests/classification/test_calibration_error.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 259597d7575..54949e18259 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -137,7 +137,7 @@ def _sk_multiclass_calibration_error(preds, target, n_bins, norm, ignore_index): "input", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) ) class TestMulticlassCalibrationError(MetricTester): - @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("n_bins", [15, 20]) @pytest.mark.parametrize("norm", ["l1", "max"]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) 
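The widened stride in ``inject_ignore_index`` above overwrites fewer positions with the ignore value. Conceptually, the round trip used throughout the new tests looks like the sketch below; the body of the removal helper is an assumption for illustration, not the actual ``remove_ignore_index`` implementation:

    import torch
    from torch import Tensor

    def remove_ignore_index_sketch(target: Tensor, preds: Tensor, ignore_index):
        # drop every position flagged with ignore_index before comparing against the sklearn reference
        if ignore_index is not None:
            keep = target != ignore_index
            target, preds = target[keep], preds[keep]
        return target, preds

    target = torch.tensor([1, 0, -1, 1, -1, 0])   # -1 previously injected by inject_ignore_index
    preds = torch.rand(6)
    target, preds = remove_ignore_index_sketch(target, preds, ignore_index=-1)
    target.shape, preds.shape                      # both torch.Size([4])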
@@ -159,7 +159,7 @@ def test_multiclass_calibration_error(self, input, ddp, n_bins, norm, ignore_ind }, ) - @pytest.mark.parametrize("n_bins", [10, 15, 20]) + @pytest.mark.parametrize("n_bins", [15, 20]) @pytest.mark.parametrize("norm", ["l1", "max"]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_multiclass_calibration_error_functional(self, input, n_bins, norm, ignore_index): From 315aee86887d38c8aa5fc53bcb85f8920f87a89f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 17 Aug 2022 17:12:50 +0200 Subject: [PATCH 15/17] cast dtype --- src/torchmetrics/functional/classification/calibration_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/calibration_error.py b/src/torchmetrics/functional/classification/calibration_error.py index 2ce8e9dd29f..d56eeb9190e 100644 --- a/src/torchmetrics/functional/classification/calibration_error.py +++ b/src/torchmetrics/functional/classification/calibration_error.py @@ -272,7 +272,7 @@ def _multiclass_calibration_error_update( preds = preds.softmax(1) confidences, predictions = preds.max(dim=1) accuracies = predictions.eq(target) - return confidences, accuracies + return confidences.float(), accuracies.float() def multiclass_calibration_error( From bd15a77c46c2c32e4ee5b5086a678187ce3ace32 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 19 Aug 2022 09:16:56 +0200 Subject: [PATCH 16/17] ci --- .github/workflows/ci_test-conda.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci_test-conda.yml b/.github/workflows/ci_test-conda.yml index 01d38cc97ad..3cb0cec5677 100644 --- a/.github/workflows/ci_test-conda.yml +++ b/.github/workflows/ci_test-conda.yml @@ -77,9 +77,11 @@ jobs: python ./.github/assistant.py prune-packages requirements/detection.txt torchvision # import of PILLOW_VERSION which they recently removed in v9.0 in favor of __version__ pip install -q "Pillow<9.0" # It messes with torchvision - pip install -e . -r requirements/devel.txt -f https://download.pytorch.org/whl/cpu/torch_stable.html + pip install -e . 
-r requirements/devel.txt "torch==${{ matrix.pytorch-version }}.*" -f $TORCH_URL pip list python -c "from torch import __version__ as ver; assert '.'.join(ver.split('.')[:2]) == '${{ matrix.pytorch-version }}', ver" + env: + TORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html - name: DocTests working-directory: ./src From 7ca13255a41e7857e76cd45847fbed6087359dee Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 20 Aug 2022 19:47:09 +0200 Subject: [PATCH 17/17] skip non supported cpu half tests --- tests/unittests/classification/test_calibration_error.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index 54949e18259..e3dc891f43f 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -25,7 +25,7 @@ binary_calibration_error, multiclass_calibration_error, ) -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6, _TORCH_GREATER_EQUAL_1_9 from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all from unittests.helpers.testers import NUM_CLASSES, MetricTester, inject_ignore_index, remove_ignore_index @@ -192,8 +192,8 @@ def test_multiclass_calibration_error_differentiability(self, input): @pytest.mark.parametrize("dtype", [torch.half, torch.double]) def test_multiclass_calibration_error_dtype_cpu(self, input, dtype): preds, target = input - if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: - pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_9: + pytest.xfail(reason="torch.max in metric not supported before pytorch v1.9 for cpu + half") if (preds < 0).any() and dtype == torch.half: pytest.xfail(reason="torch.softmax in metric does not support cpu + half precision") self.run_precision_test_cpu(