From 7c798012b91a549cc813c863b0f0e42c4ee94d58 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 27 May 2022 16:05:44 +0200 Subject: [PATCH 01/74] base structure --- tests/classification_new/__init__.py | 13 +++++++++++++ tests/classification_new/test_confusion_matrix.py | 13 +++++++++++++ torchmetrics/classification_new/__init__.py | 13 +++++++++++++ torchmetrics/classification_new/confusion_matrix.py | 13 +++++++++++++ .../functional/classification_new/__init__.py | 13 +++++++++++++ .../classification_new/confusion_matrix.py | 13 +++++++++++++ 6 files changed, 78 insertions(+) create mode 100644 tests/classification_new/__init__.py create mode 100644 tests/classification_new/test_confusion_matrix.py create mode 100644 torchmetrics/classification_new/__init__.py create mode 100644 torchmetrics/classification_new/confusion_matrix.py create mode 100644 torchmetrics/functional/classification_new/__init__.py create mode 100644 torchmetrics/functional/classification_new/confusion_matrix.py diff --git a/tests/classification_new/__init__.py b/tests/classification_new/__init__.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/tests/classification_new/__init__.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/classification_new/test_confusion_matrix.py b/tests/classification_new/test_confusion_matrix.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/tests/classification_new/test_confusion_matrix.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/torchmetrics/classification_new/__init__.py b/torchmetrics/classification_new/__init__.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/classification_new/__init__.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/torchmetrics/classification_new/confusion_matrix.py b/torchmetrics/classification_new/confusion_matrix.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/classification_new/confusion_matrix.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/torchmetrics/functional/classification_new/__init__.py b/torchmetrics/functional/classification_new/__init__.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/functional/classification_new/__init__.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py b/torchmetrics/functional/classification_new/confusion_matrix.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/functional/classification_new/confusion_matrix.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 121b9fabacef8bcb9d8b691f31fcfda9123b783d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 27 May 2022 16:08:11 +0200 Subject: [PATCH 02/74] bincount --- torchmetrics/utilities/compute.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/torchmetrics/utilities/compute.py b/torchmetrics/utilities/compute.py index f496baff818..73faf05f550 100644 --- a/torchmetrics/utilities/compute.py +++ b/torchmetrics/utilities/compute.py @@ -38,3 +38,21 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: res = x * torch.log(y) res[x == 0] = 0.0 return res + + +def _bincount(x: Tensor, minlength: int) -> Tensor: + """``torch.bincount`` is currently slow on GPU and non-deterministic. This function instead relies on + broadcasting and summation which is fast but uses more memory. 
+ + Args: + x: tensor to count + minlength: minimum length to count + + Returns: + Number of occurrences for each unique element in x + """ + if x.is_cuda: + labels = torch.arange(minlength, dtype=x.dtype, device=x.device).unsqueeze(1) + return (x.unsqueeze(0) == labels).sum(dim=-1) + else: + return torch.bincount(x, minlength=minlength) From 1074281043b25953389f71111b23e7fbea04441e Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 27 May 2022 16:56:18 +0200 Subject: [PATCH 03/74] binary --- .../classification_new/confusion_matrix.py | 48 +++++++ .../classification_new/confusion_matrix.py | 121 ++++++++++++++++++ torchmetrics/utilities/checks.py | 4 +- 3 files changed, 172 insertions(+), 1 deletion(-) diff --git a/torchmetrics/classification_new/confusion_matrix.py b/torchmetrics/classification_new/confusion_matrix.py index d7aa17d7f84..6c652f7b41f 100644 --- a/torchmetrics/classification_new/confusion_matrix.py +++ b/torchmetrics/classification_new/confusion_matrix.py @@ -11,3 +11,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, Optional + +import torch +from torch import Tensor + +from torchmetrics.functional.classification_new.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, +) +from torchmetrics.metric import Metric + + +class BinaryConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + confmat: Tensor + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: bool = False, + validate_args: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(2, 2), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + self.confmat += confmat + + def compute(self) -> Tensor: + return _binary_confusion_matrix_compute(self.confmat, self.normalize) diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py b/torchmetrics/functional/classification_new/confusion_matrix.py index d7aa17d7f84..de07628ab4d 100644 --- a/torchmetrics/functional/classification_new/confusion_matrix.py +++ b/torchmetrics/functional/classification_new/confusion_matrix.py @@ -11,3 +11,124 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional, Tuple + +import torch +from torch import Tensor + +from torchmetrics.utilities.checks import _check_same_shape +from torchmetrics.utilities.compute import _bincount +from torchmetrics.utilities.prints import rank_zero_warn + + +def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, multilabel: bool = False) -> Tensor: + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Argument `normalize` needs to be one of the following: {allowed_normalize}") + if normalize is not None and normalize != "none": + confmat = confmat.float() if not confmat.is_floating_point() else confmat + if normalize == "true": + confmat = confmat / confmat.sum(axis=2 if multilabel else 1, keepdim=True) + elif normalize == "pred": + confmat = confmat / confmat.sum(axis=1 if multilabel else 0, keepdim=True) + elif normalize == "all": + confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1]) + + nan_elements = confmat[torch.isnan(confmat)].nelement() + if nan_elements != 0: + confmat[torch.isnan(confmat)] = 0 + rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.") + return confmat + + +def _binary_confusion_matrix_arg_validation( + threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None +) -> None: + """Validate non-tensor input.""" + if not isinstance(threshold, float): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _binary_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input.""" + # Check that they have the same shape + _check_same_shape(preds, target) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0, 1] + ([] if ignore_index is None else [ignore_index])}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor."
+ ) + + +def _binary_confusion_matrix_format( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format.""" + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if preds.is_floating_point(): + if not ((0 <= preds) * (preds <= 1)).all(): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + preds = preds > threshold + + return preds, target + + +def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: + """Calculate confusion matrix on current input.""" + unique_mapping = (target * 2 + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=4) + return bins.reshape(2, 2) + + +def _binary_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: + """Calculate final confusion matrix.""" + return _confusion_matrix_reduce(confmat, normalize, multilabel=False) + + +def binary_confusion_matrix( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + _binary_confusion_matrix_tensor_validation(preds, target, threshold, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _binary_confusion_matrix_compute(confmat, normalize) diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index 16cb4a267c0..bf853e1226c 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -30,7 +30,9 @@ def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool: def _check_same_shape(preds: Tensor, target: Tensor) -> None: """Check that predictions and target have the same shape, else raise error.""" if preds.shape != target.shape: - raise RuntimeError("Predictions and targets are expected to have the same shape") + raise RuntimeError( + f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}." + ) From 4f4dad6d6edab55aaba19a0153cd83b381b1d0b1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 May 2022 14:04:42 +0200 Subject: [PATCH 04/74] files --- tests/classification_new/test_stat_scores.py | 13 +++++++++++++ torchmetrics/classification_new/stat_scores.py | 13 +++++++++++++ .../functional/classification_new/stat_scores.py | 13 +++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 tests/classification_new/test_stat_scores.py create mode 100644 torchmetrics/classification_new/stat_scores.py create mode 100644 torchmetrics/functional/classification_new/stat_scores.py diff --git a/tests/classification_new/test_stat_scores.py b/tests/classification_new/test_stat_scores.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/tests/classification_new/test_stat_scores.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/torchmetrics/classification_new/stat_scores.py b/torchmetrics/classification_new/stat_scores.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/classification_new/stat_scores.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/torchmetrics/functional/classification_new/stat_scores.py b/torchmetrics/functional/classification_new/stat_scores.py new file mode 100644 index 00000000000..d7aa17d7f84 --- /dev/null +++ b/torchmetrics/functional/classification_new/stat_scores.py @@ -0,0 +1,13 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From e58c9a598655a2594af1456e33ccd71aa4ce227d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sat, 28 May 2022 15:28:49 +0200 Subject: [PATCH 05/74] stat score --- .../classification_new/confusion_matrix.py | 1 - .../classification_new/stat_scores.py | 65 ++++++++++ .../classification_new/confusion_matrix.py | 2 +- .../classification_new/stat_scores.py | 116 ++++++++++++++++++ 4 files changed, 182 insertions(+), 2 deletions(-) diff --git a/torchmetrics/classification_new/confusion_matrix.py b/torchmetrics/classification_new/confusion_matrix.py index 6c652f7b41f..bd27bcaddbe 100644 --- a/torchmetrics/classification_new/confusion_matrix.py +++ b/torchmetrics/classification_new/confusion_matrix.py @@ -30,7 +30,6 @@ class BinaryConfusionMatrix(Metric): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False - confmat: Tensor def __init__( self, diff --git a/torchmetrics/classification_new/stat_scores.py b/torchmetrics/classification_new/stat_scores.py index d7aa17d7f84..d65694cb140 100644 --- a/torchmetrics/classification_new/stat_scores.py +++ b/torchmetrics/classification_new/stat_scores.py @@ -11,3 +11,68 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Any, Dict, Optional + +import torch +from torch import Tensor + +from torchmetrics.functional.classification_new.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_compute, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, +) +from torchmetrics.metric import Metric + + +class BinaryStatScores(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + self.threshold = threshold + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + if self.multidim_average == "samplewise": + self.add_state("tp", [], dist_reduce_fx="cat") + self.add_state("fp", [], dist_reduce_fx="cat") + self.add_state("tn", [], dist_reduce_fx="cat") + self.add_state("fn", [], dist_reduce_fx="cat") + else: + self.add_state("tp", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("fp", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("tn", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("fn", torch.zeros(1), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) + if self.multidim_average == "samplewise": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def compute(self) -> Tensor: + return _binary_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py b/torchmetrics/functional/classification_new/confusion_matrix.py index de07628ab4d..f17ffcfd15f 100644 --- a/torchmetrics/functional/classification_new/confusion_matrix.py +++ b/torchmetrics/functional/classification_new/confusion_matrix.py @@ -128,7 +128,7 @@ def binary_confusion_matrix( ) -> Tensor: if validate_args: _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) - _binary_confusion_matrix_tensor_validation(preds, target, threshold, ignore_index) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) confmat = _binary_confusion_matrix_update(preds, target) return _binary_confusion_matrix_compute(confmat, normalize) diff --git a/torchmetrics/functional/classification_new/stat_scores.py b/torchmetrics/functional/classification_new/stat_scores.py index d7aa17d7f84..d1d6dc59452 100644 --- a/torchmetrics/functional/classification_new/stat_scores.py +++ b/torchmetrics/functional/classification_new/stat_scores.py @@ -11,3 +11,119 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional, Tuple + +import torch +from torch import Tensor + +from torchmetrics.utilities.checks import _check_same_shape + + +def _binary_stat_scores_arg_validation( + threshold: float = 0.5, + multidim_average: str = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non-tensor input.""" + if not isinstance(threshold, float): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + allowed_multidim_average = ("global", "samplewise") + if not isinstance(multidim_average, str) and multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _binary_stat_scores_tensor_validation( + preds: Tensor, target: Tensor, multidim_average: str = "global", ignore_index: Optional[int] = None +) -> None: + """Validate tensor input.""" + # Check that they have the same shape + _check_same_shape(preds, target) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0, 1] + ([] if ignore_index is None else [ignore_index])}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor."
+ ) + + if multidim_average != "global" and preds.ndim < 2: + raise ValueError("Expected input to be atleast 2D when multidim_average is set to `samplewise`") + + +def _binary_stat_scores_format( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format.""" + if preds.is_floating_point(): + if not ((0 <= preds) * (preds <= 1)).all(): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + preds = preds > threshold + + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + if ignore_index is not None: + idx = target == ignore_index + target[idx] = -1 + + return preds, target + + +def _binary_stat_scores_update( + preds: Tensor, + target: Tensor, + multidim_average: str = "global", +) -> Tensor: + """""" + sum_dim = 0 if multidim_average == "global" else 1 + tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() + fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() + fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() + tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() + return tp, fp, tn, fn + + +def _binary_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" +) -> Tensor: + if multidim_average == "global": + return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + + +def binary_stat_scores( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) From 0ba326f5b64d951adf4df942e93f7f0fdb0609b9 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 30 May 2022 12:07:41 +0200 Subject: [PATCH 06/74] multiclass + multilabel confmat --- .../test_confusion_matrix.py | 44 +++++ .../classification_new/confusion_matrix.py | 80 ++++++++- .../classification_new/confusion_matrix.py | 165 ++++++++++++++++++ 3 files changed, 288 insertions(+), 1 deletion(-) diff --git a/tests/classification_new/test_confusion_matrix.py b/tests/classification_new/test_confusion_matrix.py index d7aa17d7f84..3088a64c91d 100644 --- a/tests/classification_new/test_confusion_matrix.py +++ b/tests/classification_new/test_confusion_matrix.py @@ -11,3 +11,47 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +import pytest + +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester + + +@pytest.mark.parametrize( + "preds, target", + [ + (_input_binary_int.preds, _input_binary_int.target), + (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary_logit.preds, _input_binary_logit.target), + (_input_binary_int_multidim.preds, _input_binary_int_multidim.target), + (_input_binary_prob_multidim.preds, _input_binary_prob_multidim.target), + (_input_binary_logit_multidim.preds, _input_binary_logit_multidim.target), + ] +) +@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +class TestConfusionMatrix(MetricTester) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_confusion_matrix(self, preds, target, ddp, normalize): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryConfusionMatrix, + sk_metric=_sk_confusion_matrix_binary, + metric_args={ + "threshold": THRESHOLD + "normalize": normalize + } + ) + + def test_confusion_matrix_functional(self, preds, target, ddp, normalize): + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_confusion_matrix, + sk_metric=_sk_confusion_matrix_binary, + metric_args={ + "threshold": THRESHOLD + "normalize": normalize + } + ) diff --git a/torchmetrics/classification_new/confusion_matrix.py b/torchmetrics/classification_new/confusion_matrix.py index bd27bcaddbe..1a69a58f572 100644 --- a/torchmetrics/classification_new/confusion_matrix.py +++ b/torchmetrics/classification_new/confusion_matrix.py @@ -22,6 +22,16 @@ _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric @@ -35,7 +45,7 @@ def __init__( self, threshold: float = 0.5, ignore_index: Optional[int] = None, - normalize: bool = False, + normalize: Optional[str] = None, validate_args: bool = True, **kwargs: Dict[str, Any], ) -> None: @@ -58,3 +68,71 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: return _binary_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultiClassConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, 
target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + self.confmat += confmat + + def compute(self) -> Tensor: + return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultiLabelConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, ignore_index, normalize) + self.num_labels = num_labels + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_labels, 2, 2), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) + self.confmat += confmat + + def compute(self) -> Tensor: + return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py b/torchmetrics/functional/classification_new/confusion_matrix.py index f17ffcfd15f..8b7a40f0c3e 100644 --- a/torchmetrics/functional/classification_new/confusion_matrix.py +++ b/torchmetrics/functional/classification_new/confusion_matrix.py @@ -132,3 +132,168 @@ def binary_confusion_matrix( preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) confmat = _binary_confusion_matrix_update(preds, target) return _binary_confusion_matrix_compute(confmat, normalize) + + +def _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) -> None: + if not isinstance(num_classes, int) and num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) -> None: + """Validate tensor input.""" + if preds.ndim == target.ndim + 1: + if not preds.is_floating_point(): + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be equal to number of classes" + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." 
+ ) + elif preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + unique_values = torch.unique(target) + if ignore_index is None: + check = len(unique_values) > num_classes + else: + check = len(unique_values) > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found" + f"{len(unique_values)} in `target`." + ) + + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if len(unique_values) > num_classes: + raise RuntimeError( + "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {len(unique_values)} in `preds`." + ) + + +def _multiclass_confusion_matrix_format(preds, target, ignore_index) -> Tuple[Tensor, Tensor]: + # Apply argmax if we have one more dimension + if preds.ndim == target.ndim + 1: + preds = preds.argmax(dim=1) + + preds = preds.flatten() + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + return preds, target + + +def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor: + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + return bins.reshape(num_classes, num_classes) + + +def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: + return _confusion_matrix_reduce(confmat, normalize, multilabel=False) + + +def multiclass_confusion_matrix( + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _multiclass_confusion_matrix_compute(confmat, normalize) + + +def _multilabel_confusion_matrix_arg_validation( + num_labels: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None +) -> None: + if not isinstance(num_labels, int) and num_labels < 2: + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_classes}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multilabel_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + pass + + +def _multilabel_confusion_matrix_format( + preds: Tensor, target: Tensor, num_labels: int, 
threshold: float = 0.5, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: + if preds.is_floating_point(): + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.sigmoid() + preds = preds > threshold + + preds = preds.movedim(1, -1).reshape(-1, num_labels) + target = target.movedim(1, -1).reshape(-1, num_labels) + + if ignore_index is not None: + # make sure that when we map, it will always result in a negative number that we can filter away + idx = target == ignore_index + preds[idx] = -4 * num_labels + target[idx] = -4 * num_labels + + return preds, target + + +def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: + unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = _bincount(unique_mapping, minlength=4 * num_labels) + return bins.reshape(num_labels, 2, 2) + + +def _multilabel_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: + return _confusion_matrix_reduce(confmat, normalize, multilabel=True) + + +def multilabel_confusion_matrix( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, ignore_index, normalize) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _multilabel_confusion_matrix_compute(confmat, normalize) From 85471c4bec48c264a555ecd907171d676c072278 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 30 May 2022 12:08:00 +0200 Subject: [PATCH 07/74] update --- tests/classification_new/test_confusion_matrix.py | 6 +++--- .../functional/classification_new/confusion_matrix.py | 5 +++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/classification_new/test_confusion_matrix.py index 3088a64c91d..a395ad71e6a 100644 --- a/tests/classification_new/test_confusion_matrix.py +++ b/tests/classification_new/test_confusion_matrix.py @@ -29,7 +29,7 @@ ] ) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -class TestConfusionMatrix(MetricTester) +class TestConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_binary_confusion_matrix(self, preds, target, ddp, normalize): self.run_class_metric_test( @@ -39,7 +39,7 @@ def test_binary_confusion_matrix(self, preds, target, ddp, normalize): metric_class=BinaryConfusionMatrix, sk_metric=_sk_confusion_matrix_binary, metric_args={ - "threshold": THRESHOLD + "threshold": THRESHOLD, "normalize": normalize } ) @@ -51,7 +51,7 @@ def test_confusion_matrix_functional(self, preds, target, ddp, normalize): metric_functional=binary_confusion_matrix, sk_metric=_sk_confusion_matrix_binary, metric_args={ - "threshold": THRESHOLD + "threshold": THRESHOLD, "normalize": normalize } ) diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py index 8b7a40f0c3e..2c093c9c1de 100644 --- a/torchmetrics/functional/classification_new/confusion_matrix.py +++ 
b/torchmetrics/functional/classification_new/confusion_matrix.py @@ -151,7 +151,8 @@ def _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, i raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") if preds.shape[1] != num_classes: raise ValueError( - "If `preds` have one dimension more than `target`, `preds.shape[1]` should be equal to number of classes" + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" + " equal to number of classes." ) if preds.shape[2:] != target.shape[1:]: raise ValueError( @@ -237,7 +238,7 @@ def _multilabel_confusion_matrix_arg_validation( num_labels: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None ) -> None: if not isinstance(num_labels, int) and num_labels < 2: - raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_classes}") + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") allowed_normalize = ("true", "pred", "all", "none", None) From 336d277d64bf81ae7df736afb44a6901e6cada86 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 31 May 2022 09:08:33 +0200 Subject: [PATCH 08/74] stat_score --- .../classification_new/stat_scores.py | 85 ++++++++++++++++++- 1 file changed, 83 insertions(+), 2 deletions(-) diff --git a/torchmetrics/functional/classification_new/stat_scores.py b/torchmetrics/functional/classification_new/stat_scores.py index d1d6dc59452..414c64e1aae 100644 --- a/torchmetrics/functional/classification_new/stat_scores.py +++ b/torchmetrics/functional/classification_new/stat_scores.py @@ -28,7 +28,7 @@ def _binary_stat_scores_arg_validation( if not isinstance(threshold, float): raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") allowed_multidim_average = ("global", "samplewise") - if not isinstance(multidim_average, str) and multidim_average not in allowed_multidim_average: + if multidim_average not in allowed_multidim_average: raise ValueError( f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" ) @@ -112,7 +112,6 @@ def _binary_stat_scores_compute( return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) - def binary_stat_scores( preds: Tensor, target: Tensor, @@ -127,3 +126,85 @@ def binary_stat_scores( preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) + + +def _multiclass_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" +) -> Tensor: + if multidim_average == "global": + return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + + +def _multiclass_stat_scores_arg_validation( + num_classes: int, + top_k: int = 1, + average: str = 'micro', + multidim_average: str = "global", + ignore_index: Optional[int] = None, +) -> None: + if not isinstance(num_classes, int) and num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if not 
isinstance(top_k, int) and top_k < 1: + raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") + allowed_average = ("micro", "macro", "weighted", "samples", "none", None) + if average not in allowed_average: + raise ValueError( + f"Expected argument `average` to be one of {allowed_average}, but got {average}" + ) + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multiclass_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + num_classes: int, + top_k: int = 1, + multidim_average: str = 'global', + ignore_index: Optional[int] = None, +) -> None: + pass + + +def _multiclass_stat_scores_format(): + pass + + +def _multiclass_stat_scores_update(): + pass + + +def multiclass_stat_scores( + preds: Tensor, + target: Tensor, + num_classes: int, + top_k: int = 1, + average: str = "micro", + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, ignore_index) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) + return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) + + + + + +def _multilabel_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" +) -> Tensor: + if multidim_average == "global": + return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + From b9fdf6c5dae54ef79297dd7d95630fb88b3ce247 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 2 Jun 2022 16:17:41 +0200 Subject: [PATCH 09/74] change bincount --- torchmetrics/utilities/compute.py | 18 ------------------ torchmetrics/utilities/data.py | 10 ++++------ 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/torchmetrics/utilities/compute.py b/torchmetrics/utilities/compute.py index 73faf05f550..f496baff818 100644 --- a/torchmetrics/utilities/compute.py +++ b/torchmetrics/utilities/compute.py @@ -38,21 +38,3 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: res = x * torch.log(y) res[x == 0] = 0.0 return res - - -def _bincount(x: Tensor, minlength: int) -> Tensor: - """``torch.bincount`` is currently slow on GPU and non-deterministic. This function instead relies on - broadcasting and summation which is fast but uses more memory.
- - Args: - x: tensor to count - minlength: minimum length to count - - Returns: - Number of occurrences for each unique element in x - """ - if x.is_cuda: - labels = torch.arange(minlength, dtype=x.dtype, device=x.device).unsqueeze(1) - return (x.unsqueeze(0) == labels).sum(dim=-1) - else: - return torch.bincount(x, minlength=minlength) diff --git a/torchmetrics/utilities/data.py b/torchmetrics/utilities/data.py index abf19598343..84827d7876f 100644 --- a/torchmetrics/utilities/data.py +++ b/torchmetrics/utilities/data.py @@ -241,7 +241,7 @@ def _squeeze_if_scalar(data: Any) -> Any: return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor) -def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: +def _bincount(x: Tensor, minlength: int) -> Tensor: """``torch.bincount`` currently does not support deterministic mode on GPU. This implementation fallback to a for-loop counting occurrences in that case. @@ -253,15 +253,13 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: Returns: Number of occurrences for each unique element in x """ - if x.is_cuda and deterministic(): - if minlength is None: - minlength = len(torch.unique(x)) + if deterministic(): output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() return output - else: - return torch.bincount(x, minlength=minlength) + z = torch.zeros(minlength, device=x.device, dtype=x.dtype) + return z.index_add_(0, x, torch.ones_like(x)) def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: From e83c30a1bd797ed14a31753d5b6986e7a584b4a1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 2 Jun 2022 16:23:25 +0200 Subject: [PATCH 10/74] move back --- .../classification/confusion_matrix.py | 105 ++++++ torchmetrics/classification/stat_scores.py | 55 ++++ torchmetrics/classification_new/__init__.py | 13 - .../classification_new/confusion_matrix.py | 138 -------- .../classification_new/stat_scores.py | 78 ----- .../classification/confusion_matrix.py | 282 ++++++++++++++++ .../functional/classification/stat_scores.py | 192 +++++++++++ .../functional/classification_new/__init__.py | 13 - .../classification_new/confusion_matrix.py | 300 ------------------ .../classification_new/stat_scores.py | 210 ------------ 10 files changed, 634 insertions(+), 752 deletions(-) delete mode 100644 torchmetrics/classification_new/__init__.py delete mode 100644 torchmetrics/classification_new/confusion_matrix.py delete mode 100644 torchmetrics/classification_new/stat_scores.py delete mode 100644 torchmetrics/functional/classification_new/__init__.py delete mode 100644 torchmetrics/functional/classification_new/confusion_matrix.py delete mode 100644 torchmetrics/functional/classification_new/stat_scores.py diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index e4052e7e643..19b43ff4a0d 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -20,6 +20,111 @@ from torchmetrics.metric import Metric +class BinaryConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, 
ignore_index, normalize) + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(2, 2), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + self.confmat += confmat + + def compute(self) -> Tensor: + return _binary_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultiClassConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + self.confmat += confmat + + def compute(self) -> Tensor: + return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultiLabelConfusionMatrix(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + ignore_index: Optional[int] = None, + normalize: Optional[str] = None, + validate_args: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, ignore_index, normalize) + self.num_labels = num_labels + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_labels, 2, 2), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) + self.confmat += confmat + + def compute(self) -> Tensor: + return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) + + +# -------------------------- Old stuff -------------------------- + + class ConfusionMatrix(Metric): r"""Computes the `confusion matrix`_. 
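For reference, a minimal usage sketch of the `BinaryConfusionMatrix` interface added above; the expected output assumes the `2 * target + preds` bincount update introduced earlier in this series, and the import path assumes the post-move layout (`torchmetrics.classification`):

import torch

from torchmetrics.classification.confusion_matrix import BinaryConfusionMatrix

# Probabilities in [0, 1]; raw logits would first be passed through a sigmoid
# by `_binary_confusion_matrix_format`, then binarized with `threshold`.
preds = torch.tensor([0.75, 0.05, 0.35, 0.80])
target = torch.tensor([1, 0, 1, 1])

metric = BinaryConfusionMatrix(threshold=0.5)
metric.update(preds, target)
metric.compute()  # tensor([[1., 0.], [1., 2.]]) -- rows index the true class, columns the prediction

The same result is available functionally via `binary_confusion_matrix(preds, target, threshold=0.5)`.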
diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index ead35d1c7cb..0b1290844db 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -21,6 +21,61 @@ from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +class BinaryStatScores(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + self.threshold = threshold + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + if self.multidim_average == "samplewise": + self.add_state("tp", [], dist_reduce_fx="cat") + self.add_state("fp", [], dist_reduce_fx="cat") + self.add_state("tn", [], dist_reduce_fx="cat") + self.add_state("fn", [], dist_reduce_fx="cat") + else: + self.add_state("tp", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("fp", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("tn", torch.zeros(1), dist_reduce_fx="sum") + self.add_state("fn", torch.zeros(1), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) + if self.multidim_average == "samplewise": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def compute(self) -> Tensor: + return _binary_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + + +# -------------------------- Old stuff -------------------------- + + class StatScores(Metric): r"""Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. diff --git a/torchmetrics/classification_new/__init__.py b/torchmetrics/classification_new/__init__.py deleted file mode 100644 index d7aa17d7f84..00000000000 --- a/torchmetrics/classification_new/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/torchmetrics/classification_new/confusion_matrix.py b/torchmetrics/classification_new/confusion_matrix.py deleted file mode 100644 index 1a69a58f572..00000000000 --- a/torchmetrics/classification_new/confusion_matrix.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright The PyTorch Lightning team. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any, Dict, Optional - -import torch -from torch import Tensor - -from torchmetrics.functional.classification_new.confusion_matrix import ( - _binary_confusion_matrix_arg_validation, - _binary_confusion_matrix_compute, - _binary_confusion_matrix_format, - _binary_confusion_matrix_tensor_validation, - _binary_confusion_matrix_update, - _multiclass_confusion_matrix_arg_validation, - _multiclass_confusion_matrix_compute, - _multiclass_confusion_matrix_format, - _multiclass_confusion_matrix_tensor_validation, - _multiclass_confusion_matrix_update, - _multilabel_confusion_matrix_arg_validation, - _multilabel_confusion_matrix_compute, - _multilabel_confusion_matrix_format, - _multilabel_confusion_matrix_tensor_validation, - _multilabel_confusion_matrix_update, -) -from torchmetrics.metric import Metric - - -class BinaryConfusionMatrix(Metric): - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - - def __init__( - self, - threshold: float = 0.5, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__(**kwargs) - if validate_args: - _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) - self.threshold = threshold - self.ignore_index = ignore_index - self.normalize = normalize - self.validate_args = validate_args - - self.add_state("confmat", torch.zeros(2, 2), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - if self.validate_args: - _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) - confmat = _binary_confusion_matrix_update(preds, target) - self.confmat += confmat - - def compute(self) -> Tensor: - return _binary_confusion_matrix_compute(self.confmat, self.normalize) - - -class MultiClassConfusionMatrix(Metric): - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - - def __init__( - self, - num_classes: int, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - if validate_args: - _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) - self.num_classes = num_classes - self.ignore_index = ignore_index - self.normalize = normalize - self.validate_args = validate_args - - self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - if self.validate_args: - _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) - confmat = 
_multiclass_confusion_matrix_update(preds, target, self.num_classes) - self.confmat += confmat - - def compute(self) -> Tensor: - return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) - - -class MultiLabelConfusionMatrix(Metric): - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - - def __init__( - self, - num_labels: int, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, - **kwargs, - ) -> None: - super().__init__(**kwargs) - if validate_args: - _multilabel_confusion_matrix_arg_validation(num_labels, ignore_index, normalize) - self.num_labels = num_labels - self.ignore_index = ignore_index - self.normalize = normalize - self.validate_args = validate_args - - self.add_state("confmat", torch.zeros(num_labels, 2, 2), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - if self.validate_args: - _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) - preds, target = _multilabel_confusion_matrix_format(preds, target, self.ignore_index) - confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) - self.confmat += confmat - - def compute(self) -> Tensor: - return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) diff --git a/torchmetrics/classification_new/stat_scores.py b/torchmetrics/classification_new/stat_scores.py deleted file mode 100644 index d65694cb140..00000000000 --- a/torchmetrics/classification_new/stat_scores.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Any, Dict, Optional - -import torch -from torch import Tensor - -from torchmetrics.functional.classification_new.stat_scores import ( - _binary_stat_scores_arg_validation, - _binary_stat_scores_compute, - _binary_stat_scores_format, - _binary_stat_scores_tensor_validation, - _binary_stat_scores_update, -) -from torchmetrics.metric import Metric - - -class BinaryStatScores(Metric): - is_differentiable: bool = False - higher_is_better: Optional[bool] = None - full_state_update: bool = False - - def __init__( - self, - threshold: float = 0.5, - multidim_average: str = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, - **kwargs: Dict[str, Any], - ) -> None: - super().__init__(**kwargs) - if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - self.threshold = threshold - self.multidim_average = multidim_average - self.ignore_index = ignore_index - self.validate_args = validate_args - - if self.multidim_average == "samplewise": - self.add_state("tp", [], dist_reduce_fx="cat") - self.add_state("fp", [], dist_reduce_fx="cat") - self.add_state("tn", [], dist_reduce_fx="cat") - self.add_state("fn", [], dist_reduce_fx="cat") - else: - self.add_state("tp", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("fp", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("tn", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("fn", torch.zeros(1), dist_reduce_fx="sum") - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - if self.validate_args: - _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) - preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) - tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) - if self.multidim_average == "samplewise": - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - else: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn - - def compute(self) -> Tensor: - return _binary_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py index 362276dc146..7e3b6217343 100644 --- a/torchmetrics/functional/classification/confusion_matrix.py +++ b/torchmetrics/functional/classification/confusion_matrix.py @@ -22,6 +22,288 @@ from torchmetrics.utilities.enums import DataType +def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, multilabel: bool = False) -> Tensor: + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}") + if normalize is not None and normalize != "none": + confmat = confmat.float() if not confmat.is_floating_point() else confmat + if normalize == "true": + confmat = confmat / confmat.sum(axis=2 if multilabel else 1, keepdim=True) + elif normalize == "pred": + confmat = confmat / confmat.sum(axis=1 if multilabel else 0, keepdim=True) + elif normalize == "all": + confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1]) + + nan_elements = confmat[torch.isnan(confmat)].nelement() + if nan_elements != 0: + confmat[torch.isnan(confmat)] = 0 + rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.") + return confmat + + +def 
_binary_confusion_matrix_arg_validation(
+    threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None
+) -> None:
+    """Validate non tensor input."""
+    if not isinstance(threshold, float):
+        raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+    allowed_normalize = ("true", "pred", "all", "none", None)
+    if normalize not in allowed_normalize:
+        raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
+
+
+def _binary_confusion_matrix_tensor_validation(
+    preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
+) -> None:
+    """Validate tensor input."""
+    # Check that they have same shape
+    _check_same_shape(preds, target)
+
+    # Check that target only contains [0,1] values or value in ignore_index
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = torch.any((unique_values != 0) & (unique_values != 1))
+    else:
+        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
+    if check:
+        raise RuntimeError(
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0, 1] + ([] if ignore_index is None else [ignore_index])}."
+        )
+
+    # If preds is label tensor, also check that it only contains [0,1] values
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if torch.any((unique_values != 0) & (unique_values != 1)):
+            raise RuntimeError(
+                f"Detected the following values in `preds`: {unique_values} but expected only"
+                " the following values [0,1] since preds is a label tensor."
+            )
+
+
+def _binary_confusion_matrix_format(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format."""
+    preds = preds.flatten()
+    target = target.flatten()
+    if ignore_index is not None:
+        idx = target != ignore_index
+        preds = preds[idx]
+        target = target[idx]
+
+    if preds.is_floating_point():
+        if not ((0 <= preds) * (preds <= 1)).all():
+            # preds is logits, convert with sigmoid
+            preds = preds.sigmoid()
+        preds = preds > threshold
+
+    return preds, target
+
+
+def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor:
+    """Calculate confusion matrix on current input."""
+    unique_mapping = (target * 2 + preds).to(torch.long)
+    bins = _bincount(unique_mapping, minlength=4)
+    return bins.reshape(2, 2)
+
+
+def _binary_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
+    """Calculate final confusion matrix."""
+    return _confusion_matrix_reduce(confmat, normalize, multilabel=False)
+
+
+def binary_confusion_matrix(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+    normalize: Optional[str] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize)
+        _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
+    preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index)
+    confmat = _binary_confusion_matrix_update(preds, target)
+    return _binary_confusion_matrix_compute(confmat, normalize)
+
+
+def _multiclass_confusion_matrix_arg_validation(
+    num_classes: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None
+) -> None:
+    if not isinstance(num_classes, int) or num_classes < 2:
+        raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+    allowed_normalize = ("true", "pred", "all", "none", None)
+    if normalize not in allowed_normalize:
+        raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
+
+
+def _multiclass_confusion_matrix_tensor_validation(
+    preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None
+) -> None:
+    """Validate tensor input."""
+    if preds.ndim == target.ndim + 1:
+        if not preds.is_floating_point():
+            raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.")
+        if preds.shape[1] != num_classes:
+            raise ValueError(
+                "If `preds` have one dimension more than `target`, `preds.shape[1]` should be"
+                " equal to number of classes."
+            )
+        if preds.shape[2:] != target.shape[1:]:
+            raise ValueError(
+                "If `preds` have one dimension more than `target`, the shape of `preds` should be"
+                " (N, C, ...), and the shape of `target` should be (N, ...)."
+            )
+    elif preds.ndim == target.ndim:
+        if preds.shape != target.shape:
+            raise ValueError(
+                "The `preds` and `target` should have the same shape,"
+                f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}."
+            )
+    else:
+        raise ValueError(
+            "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
+            " and `preds` should be (N, C, ...)."
+        )
+
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = len(unique_values) > num_classes
+    else:
+        check = len(unique_values) > num_classes + 1
+    if check:
+        raise RuntimeError(
+            "Detected more unique values in `target` than `num_classes`. Expected only "
+            f"{num_classes if ignore_index is None else num_classes + 1} but found"
+            f" {len(unique_values)} in `target`."
+        )
+
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if len(unique_values) > num_classes:
+            raise RuntimeError(
+                "Detected more unique values in `preds` than `num_classes`. Expected only "
+                f"{num_classes} but found {len(unique_values)} in `preds`."
+            )
+
+
+def _multiclass_confusion_matrix_format(
+    preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
+) -> Tuple[Tensor, Tensor]:
+    # Apply argmax if we have one more dimension
+    if preds.ndim == target.ndim + 1:
+        preds = preds.argmax(dim=1)
+
+    preds = preds.flatten()
+    target = target.flatten()
+
+    if ignore_index is not None:
+        idx = target != ignore_index
+        preds = preds[idx]
+        target = target[idx]
+
+    return preds, target
+
+
+def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor:
+    unique_mapping = (target * num_classes + preds).to(torch.long)
+    bins = _bincount(unique_mapping, minlength=num_classes**2)
+    return bins.reshape(num_classes, num_classes)
+
+
+def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
+    return _confusion_matrix_reduce(confmat, normalize, multilabel=False)
+
+
+def multiclass_confusion_matrix(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+    normalize: Optional[str] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize)
+        _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
+    confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
+    return _multiclass_confusion_matrix_compute(confmat, normalize)
+
+
+def _multilabel_confusion_matrix_arg_validation(
+    num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None
+) -> None:
+    if not isinstance(num_labels, int) or num_labels < 2:
+        raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}")
+    if not isinstance(threshold, float):
+        raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+    allowed_normalize = ("true", "pred", "all", "none", None)
+    if normalize not in allowed_normalize:
+        raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.")
+
+
+def _multilabel_confusion_matrix_tensor_validation(
+    preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None
+) -> None:
+    pass
+
+
+def _multilabel_confusion_matrix_format(
+    preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None
+) -> Tuple[Tensor, Tensor]:
+    if preds.is_floating_point():
+        if not ((0 <= preds) * (preds <= 1)).all():
+            preds = preds.sigmoid()
+        preds = preds > threshold
+
+    preds = preds.movedim(1, -1).reshape(-1, num_labels)
+    target = target.movedim(1, -1).reshape(-1, num_labels)
+
+    if ignore_index is not None:
+        # make sure that when we map, it will always
result in a negative number that we can filter away
+        idx = target == ignore_index
+        preds[idx] = -4 * num_labels
+        target[idx] = -4 * num_labels
+
+    return preds, target
+
+
+def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor:
+    unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten()
+    unique_mapping = unique_mapping[unique_mapping >= 0]
+    bins = _bincount(unique_mapping, minlength=4 * num_labels)
+    return bins.reshape(num_labels, 2, 2)
+
+
+def _multilabel_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor:
+    return _confusion_matrix_reduce(confmat, normalize, multilabel=True)
+
+
+def multilabel_confusion_matrix(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+    normalize: Optional[str] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize)
+        _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
+    preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
+    confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
+    return _multilabel_confusion_matrix_compute(confmat, normalize)
+
+
+# -------------------------- Old stuff --------------------------
+
+
 def _confusion_matrix_update(
     preds: Tensor, target: Tensor, num_classes: int, threshold: float = 0.5, multilabel: bool = False
 ) -> Tensor:
diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py
index b3cb7786e49..f48bc364f97 100644
--- a/torchmetrics/functional/classification/stat_scores.py
+++ b/torchmetrics/functional/classification/stat_scores.py
@@ -20,6 +20,198 @@
 from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod
 
 
+def _binary_stat_scores_arg_validation(
+    threshold: float = 0.5,
+    multidim_average: str = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    """Validate non tensor input."""
+    if not isinstance(threshold, float):
+        raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
+    allowed_multidim_average = ("global", "samplewise")
+    if multidim_average not in allowed_multidim_average:
+        raise ValueError(
+            f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}"
+        )
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+
+
+def _binary_stat_scores_tensor_validation(
+    preds: Tensor, target: Tensor, multidim_average: str = "global", ignore_index: Optional[int] = None
+) -> None:
+    """Validate tensor input."""
+    # Check that they have same shape
+    _check_same_shape(preds, target)
+
+    # Check that target only contains [0,1] values or value in ignore_index
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = torch.any((unique_values != 0) & (unique_values != 1))
+    else:
+        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
+    if check:
+        raise RuntimeError(
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0, 1] + ([] if ignore_index is None else [ignore_index])}."
+            )
+
+    # If preds is label tensor, also check that it only contains [0,1] values
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if torch.any((unique_values != 0) & (unique_values != 1)):
+            raise RuntimeError(
+                f"Detected the following values in `preds`: {unique_values} but expected only"
+                " the following values [0,1] since preds is a label tensor."
+            )
+
+    if multidim_average != "global" and preds.ndim < 2:
+        raise ValueError("Expected input to be at least 2D when multidim_average is set to `samplewise`")
+
+
+def _binary_stat_scores_format(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format."""
+    if preds.is_floating_point():
+        if not ((0 <= preds) * (preds <= 1)).all():
+            # preds is logits, convert with sigmoid
+            preds = preds.sigmoid()
+        preds = preds > threshold
+
+    preds = preds.reshape(preds.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+
+    if ignore_index is not None:
+        idx = target == ignore_index
+        target[idx] = -1
+
+    return preds, target
+
+
+def _binary_stat_scores_update(
+    preds: Tensor,
+    target: Tensor,
+    multidim_average: str = "global",
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Compute true positives, false positives, true negatives and false negatives."""
+    sum_dim = 0 if multidim_average == "global" else 1
+    tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze()
+    fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze()
+    fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze()
+    tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze()
+    return tp, fp, tn, fn
+
+
+def _binary_stat_scores_compute(
+    tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global"
+) -> Tensor:
+    if multidim_average == "global":
+        return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0)
+    return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1)
+
+
+def binary_stat_scores(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    multidim_average: str = "global",
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
+        _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
+    preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index)
+    tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average)
+    return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average)
+
+
+def _multiclass_stat_scores_compute(
+    tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global"
+) -> Tensor:
+    if multidim_average == "global":
+        return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0)
+    return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1)
+
+
+def _multiclass_stat_scores_arg_validation(
+    num_classes: int,
+    top_k: int = 1,
+    average: str = "micro",
+    multidim_average: str = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    if not isinstance(num_classes, int) or num_classes < 2:
+        raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
+    if not isinstance(top_k, int) or top_k < 1:
+        raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}")
+    allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
+    if average not in allowed_average:
+        raise ValueError(
+            f"Expected argument `average` to be one of 
{allowed_average}, but got {average}" + ) + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multiclass_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + num_classes: int, + top_k: int = 1, + multidim_average: str = 'global', + ignore_index: Optional[int] = None, +) -> None: + pass + + +def _multiclass_stat_scores_format(): + pass + + +def _multiclass_stat_scores_update(): + + + +def multiclass_stat_scores( + preds: Tensor, + target: Tensor, + num_classes: int, + top_k: int = 1, + average: str = "micro", + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, ignore_index) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) + return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) + + + + + +def _multilabel_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" +) -> Tensor: + if multidim_average == "global": + return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + +# -------------------------- Old stuff -------------------------- + + def _del_column(data: Tensor, idx: int) -> Tensor: """Delete the column at index.""" return torch.cat([data[:, :idx], data[:, (idx + 1) :]], 1) diff --git a/torchmetrics/functional/classification_new/__init__.py b/torchmetrics/functional/classification_new/__init__.py deleted file mode 100644 index d7aa17d7f84..00000000000 --- a/torchmetrics/functional/classification_new/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/torchmetrics/functional/classification_new/confusion_matrix.py b/torchmetrics/functional/classification_new/confusion_matrix.py deleted file mode 100644 index 2c093c9c1de..00000000000 --- a/torchmetrics/functional/classification_new/confusion_matrix.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Tuple - -import torch -from torch import Tensor - -from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _bincount -from torchmetrics.utilities.prints import rank_zero_warn - - -def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, multilabel: bool = False) -> Tensor: - allowed_normalize = ("true", "pred", "all", "none", None) - if normalize not in allowed_normalize: - raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}") - if normalize is not None and normalize != "none": - confmat = confmat.float() if not confmat.is_floating_point() else confmat - if normalize == "true": - confmat = confmat / confmat.sum(axis=2 if multilabel else 1, keepdim=True) - elif normalize == "pred": - confmat = confmat / confmat.sum(axis=1 if multilabel else 0, keepdim=True) - elif normalize == "all": - confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1]) - - nan_elements = confmat[torch.isnan(confmat)].nelement() - if nan_elements != 0: - confmat[torch.isnan(confmat)] = 0 - rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.") - return confmat - - -def _binary_confusion_matrix_arg_validation( - threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None -) -> None: - """Validate non tensor input.""" - if not isinstance(threshold, float): - raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") - if ignore_index is not None and not isinstance(ignore_index, int): - raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") - allowed_normalize = ("true", "pred", "all", "none", None) - if normalize not in allowed_normalize: - raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") - - -def _binary_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = bool -) -> None: - """Validate tensor input.""" - # Check that they have same shape - _check_same_shape(preds, target) - - # Check that target only contains [0,1] values or value in ignore_index - unique_values = torch.unique(target) - if ignore_index is None: - check = torch.any((unique_values != 0) & (unique_values != 1)) - else: - check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) - if check: - raise RuntimeError( - "Detected the following values in `target`: {unique_values} but expected only" - " the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." - ) - - # If preds is label tensor, also check that it only contains [0,1] values - if not preds.is_floating_point(): - unique_values = torch.unique(preds) - if torch.any((unique_values != 0) & (unique_values != 1)): - raise RuntimeError( - "Detected the following values in `preds`: {unique_values} but expected only" - " the following values [0,1] since preds is a label tensor." 
- ) - - -def _binary_confusion_matrix_format( - preds: Tensor, - target: Tensor, - threshold: float = 0.5, - ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: - """Convert all input to label format.""" - preds = preds.flatten() - target = target.flatten() - if ignore_index is not None: - idx = target != ignore_index - preds = preds[idx] - target = target[idx] - - if preds.is_floating_point(): - if not ((0 <= preds) * (preds <= 1)).all(): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() - preds = preds > threshold - - return preds, target - - -def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: - """Calculate confusion matrix on current input.""" - unique_mapping = (target * 2 + preds).to(torch.long) - bins = _bincount(unique_mapping, minlength=4) - return bins.reshape(2, 2) - - -def _binary_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - """Calculate final confusion matrix.""" - return _confusion_matrix_reduce(confmat, normalize, multilabel=False) - - -def binary_confusion_matrix( - preds: Tensor, - target: Tensor, - threshold: float = 0.5, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) - _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) - preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) - confmat = _binary_confusion_matrix_update(preds, target) - return _binary_confusion_matrix_compute(confmat, normalize) - - -def _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) -> None: - if not isinstance(num_classes, int) and num_classes < 2: - raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") - if ignore_index is not None and not isinstance(ignore_index, int): - raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") - allowed_normalize = ("true", "pred", "all", "none", None) - if normalize not in allowed_normalize: - raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") - - -def _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) -> None: - """Validate tensor input.""" - if preds.ndim == target.ndim + 1: - if not preds.is_floating_point(): - raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if preds.shape[1] != num_classes: - raise ValueError( - "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" - " equal to number of classes." - ) - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "If `preds` have one dimension more than `target`, the shape of `preds` should be" - " (N, C, ...), and the shape of `target` should be (N, ...)." - ) - elif preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "The `preds` and `target` should have the same shape,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - else: - raise ValueError( - "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" - " and `preds` should be (N, C, ...)." 
- ) - - unique_values = torch.unique(target) - if ignore_index is None: - check = len(unique_values) > num_classes - else: - check = len(unique_values) > num_classes + 1 - if check: - raise RuntimeError( - "Detected more unique values in `target` than `num_classes`. Expected only " - f"{num_classes if ignore_index is None else num_classes + 1} but found" - f"{len(unique_values)} in `target`." - ) - - if not preds.is_floating_point(): - unique_values = torch.unique(preds) - if len(unique_values) > num_classes: - raise RuntimeError( - "Detected more unique values in `preds` than `num_classes`. Expected only " - f"{num_classes} but found {len(unique_values)} in `preds`." - ) - - -def _multiclass_confusion_matrix_format(preds, target, ignore_index) -> Tuple[Tensor, Tensor]: - # Apply argmax if we have one more dimension - if preds.ndim == target.ndim + 1: - preds = preds.argmax(dim=1) - - preds = preds.flatten() - target = target.flatten() - - if ignore_index is not None: - idx = target != ignore_index - preds = preds[idx] - target = target[idx] - - return preds, target - - -def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor: - unique_mapping = (target * num_classes + preds).to(torch.long) - bins = _bincount(unique_mapping, minlength=num_classes**2) - return bins.reshape(num_classes, num_classes) - - -def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - return _confusion_matrix_reduce(confmat, normalize, multilabel=False) - - -def multiclass_confusion_matrix( - preds: Tensor, - target: Tensor, - num_classes: int, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) - _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) - preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) - confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) - return _multiclass_confusion_matrix_compute(confmat, normalize) - - -def _multilabel_confusion_matrix_arg_validation( - num_labels: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None -) -> None: - if not isinstance(num_labels, int) and num_labels < 2: - raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") - if ignore_index is not None and not isinstance(ignore_index, int): - raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") - allowed_normalize = ("true", "pred", "all", "none", None) - if normalize not in allowed_normalize: - raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") - - -def _multilabel_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None -) -> None: - pass - - -def _multilabel_confusion_matrix_format( - preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None -) -> Tuple[Tensor, Tensor]: - if preds.is_floating_point(): - if not ((0 <= preds) * (preds <= 1)).all(): - preds = preds.sigmoid() - preds = preds > threshold - - preds = preds.movedim(1, -1).reshape(-1, num_labels) - target = target.movedim(1, -1).reshape(-1, num_labels) - - if ignore_index is not None: - # make sure that when we map, it will always 
result in a negative number that we can filter away - idx = target == ignore_index - preds[idx] = -4 * num_labels - target[idx] = -4 * num_labels - - return preds, target - - -def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: - unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() - unique_mapping = unique_mapping[unique_mapping > 0] - bins = _bincount(unique_mapping, minlength=4 * num_labels) - return bins.reshape(num_labels, 2, 2) - - -def _multilabel_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - return _confusion_matrix_reduce(confmat, normalize, multilabel=True) - - -def multilabel_confusion_matrix( - preds: Tensor, - target: Tensor, - num_labels: int, - threshold: float = 0.5, - ignore_index: Optional[int] = None, - normalize: Optional[str] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) - _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) - preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) - confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) - return _multilabel_confusion_matrix_compute(confmat, normalize) diff --git a/torchmetrics/functional/classification_new/stat_scores.py b/torchmetrics/functional/classification_new/stat_scores.py deleted file mode 100644 index 414c64e1aae..00000000000 --- a/torchmetrics/functional/classification_new/stat_scores.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import Optional, Tuple - -import torch -from torch import Tensor - -from torchmetrics.utilities.checks import _check_same_shape - - -def _binary_stat_scores_arg_validation( - threshold: float = 0.5, - multidim_average: str = "global", - ignore_index: Optional[int] = None, -) -> None: - """Validate non tensor input.""" - if not isinstance(threshold, float): - raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") - allowed_multidim_average = ("global", "samplewise") - if multidim_average not in allowed_multidim_average: - raise ValueError( - f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" - ) - if ignore_index is not None and not isinstance(ignore_index, int): - raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") - - -def _binary_stat_scores_tensor_validation( - preds: Tensor, target: Tensor, multidim_average: str = "global", ignore_index: Optional[int] = bool -) -> None: - """Validate tensor input.""" - # Check that they have same shape - _check_same_shape(preds, target) - - # Check that target only contains [0,1] values or value in ignore_index - unique_values = torch.unique(target) - if ignore_index is None: - check = torch.any((unique_values != 0) & (unique_values != 1)) - else: - check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) - if check: - raise RuntimeError( - "Detected the following values in `target`: {unique_values} but expected only" - " the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." - ) - - # If preds is label tensor, also check that it only contains [0,1] values - if not preds.is_floating_point(): - unique_values = torch.unique(preds) - if torch.any((unique_values != 0) & (unique_values != 1)): - raise RuntimeError( - "Detected the following values in `preds`: {unique_values} but expected only" - " the following values [0,1] since preds is a label tensor." 
- ) - - if multidim_average != "global" and preds.ndim < 2: - raise ValueError("Expected input to be atleast 2D when multidim_average is set to `samplewise`") - - -def _binary_stat_scores_format( - preds: Tensor, - target: Tensor, - threshold: float = 0.5, - ignore_index: Optional[int] = None, -) -> Tuple[Tensor, Tensor]: - """Convert all input to label format.""" - if preds.is_floating_point(): - if not ((0 <= preds) * (preds <= 1)).all(): - # preds is logits, convert with sigmoid - preds = preds.sigmoid() - preds = preds > threshold - - preds = preds.reshape(preds.shape[0], -1) - target = target.reshape(target.shape[0], -1) - - if ignore_index is not None: - idx = target == ignore_index - target[idx] = -1 - - return preds, target - - -def _binary_stat_scores_update( - preds: Tensor, - target: Tensor, - multidim_average: str = "global", -) -> Tensor: - """""" - sum_dim = 0 if multidim_average == "global" else 1 - tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() - fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() - fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() - tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() - return tp, fp, tn, fn - - -def _binary_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" -) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) - -def binary_stat_scores( - preds: Tensor, - target: Tensor, - threshold: float = 0.5, - multidim_average: str = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) - _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) - preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) - tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) - return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) - - -def _multiclass_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" -) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) - - -def _multiclass_stat_scores_arg_validation( - num_classes: int, - top_k: int = 1, - average: str = 'micro', - multidim_average: str = "global", - ignore_index: Optional[int] = None, -) -> None: - if not isinstance(num_classes, int) and num_classes < 2: - raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") - if not isinstance(top_k, int) and top_k < 1: - raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") - allowed_average = ("micro", "macro", "weighted,", "weighted", "samples", "none", None) - if average not in allowed_average: - raise ValueError( - f"Expected argument `average` to be one of {allowed_average}, but got {average}" - ) - allowed_multidim_average = ("global", "samplewise") - if multidim_average not in allowed_multidim_average: - raise ValueError( - f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" - ) - if ignore_index is not None and not isinstance(ignore_index, int): - raise ValueError(f"Expected argument 
`ignore_index` to either be `None` or an integer, but got {ignore_index}") - - -def _multiclass_stat_scores_tensor_validation( - preds: Tensor, - target: Tensor, - num_classes: int, - top_k: int = 1, - multidim_average: str = 'global', - ignore_index: Optional[int] = None, -) -> None: - pass - - -def _multiclass_stat_scores_format(): - pass - - -def _multiclass_stat_scores_update(): - - - -def multiclass_stat_scores( - preds: Tensor, - target: Tensor, - num_classes: int, - top_k: int = 1, - average: str = "micro", - multidim_average: str = "global", - ignore_index: Optional[int] = None, - validate_args: bool = True, -) -> Tensor: - if validate_args: - _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, ignore_index) - tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) - return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) - - - - - -def _multilabel_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" -) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) - From b32000f414c01594ad273d216da0fa8f1f2f1625 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 2 Jun 2022 16:33:52 +0200 Subject: [PATCH 11/74] del tests --- tests/classification/test_confusion_matrix.py | 41 +++++++++++++ tests/classification_new/__init__.py | 13 ----- .../test_confusion_matrix.py | 57 ------------------- tests/classification_new/test_stat_scores.py | 13 ----- .../functional/classification/__init__.py | 14 ++++- .../classification/confusion_matrix.py | 11 +++- .../functional/classification/stat_scores.py | 17 ++++-- 7 files changed, 75 insertions(+), 91 deletions(-) delete mode 100644 tests/classification_new/__init__.py delete mode 100644 tests/classification_new/test_confusion_matrix.py delete mode 100644 tests/classification_new/test_stat_scores.py diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 08ce3e3fb4d..f45c7c30c29 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -36,6 +36,47 @@ seed_all(42) +@pytest.mark.parametrize( + "preds, target", + [ + (_input_binary_int.preds, _input_binary_int.target), + (_input_binary_prob.preds, _input_binary_prob.target), + (_input_binary_logit.preds, _input_binary_logit.target), + (_input_binary_int_multidim.preds, _input_binary_int_multidim.target), + (_input_binary_prob_multidim.preds, _input_binary_prob_multidim.target), + (_input_binary_logit_multidim.preds, _input_binary_logit_multidim.target), + ] +) +@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +class TestBinaryConfusionMatrix(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_confusion_matrix(self, preds, target, ddp, normalize): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryConfusionMatrix, + sk_metric=_sk_confusion_matrix_binary, + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize + } + ) + + def test_confusion_matrix_functional(self, preds, target, ddp, normalize): + self.run_functional_metric_test( + preds=preds, + 
target=target, + metric_functional=binary_confusion_matrix, + sk_metric=_sk_confusion_matrix_binary, + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize + } + ) + +# -------------------------- Old stuff -------------------------- + def _sk_cm_binary_prob(preds, target, normalize=None): sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() diff --git a/tests/classification_new/__init__.py b/tests/classification_new/__init__.py deleted file mode 100644 index d7aa17d7f84..00000000000 --- a/tests/classification_new/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/classification_new/test_confusion_matrix.py b/tests/classification_new/test_confusion_matrix.py deleted file mode 100644 index a395ad71e6a..00000000000 --- a/tests/classification_new/test_confusion_matrix.py +++ /dev/null @@ -1,57 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import pytest - -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester - - -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_int.preds, _input_binary_int.target), - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary_logit.preds, _input_binary_logit.target), - (_input_binary_int_multidim.preds, _input_binary_int_multidim.target), - (_input_binary_prob_multidim.preds, _input_binary_prob_multidim.target), - (_input_binary_logit_multidim.preds, _input_binary_logit_multidim.target), - ] -) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -class TestConfusionMatrix(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - def test_binary_confusion_matrix(self, preds, target, ddp, normalize): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=BinaryConfusionMatrix, - sk_metric=_sk_confusion_matrix_binary, - metric_args={ - "threshold": THRESHOLD, - "normalize": normalize - } - ) - - def test_confusion_matrix_functional(self, preds, target, ddp, normalize): - self.run_functional_metric_test( - preds=preds, - target=target, - metric_functional=binary_confusion_matrix, - sk_metric=_sk_confusion_matrix_binary, - metric_args={ - "threshold": THRESHOLD, - "normalize": normalize - } - ) diff --git a/tests/classification_new/test_stat_scores.py b/tests/classification_new/test_stat_scores.py deleted file mode 100644 index d7aa17d7f84..00000000000 --- a/tests/classification_new/test_stat_scores.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
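All of the `_*_update` helpers introduced in this series reduce the confusion-matrix computation to a single `_bincount` call by mapping every (target, prediction) pair to a unique bin index, which is what makes the new implementations fast and deterministic on GPU. A standalone sketch of the multiclass mapping, using plain `torch.bincount` for clarity (variable names here are illustrative, not from the patches):

    import torch

    num_classes = 3
    target = torch.tensor([0, 1, 2, 2])
    preds = torch.tensor([0, 2, 2, 1])

    # Each (target, pred) pair gets a unique bin index: target * num_classes + pred.
    unique_mapping = (target * num_classes + preds).to(torch.long)
    bins = torch.bincount(unique_mapping, minlength=num_classes**2)
    confmat = bins.reshape(num_classes, num_classes)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 1]])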
diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py
index 70f777b56e0..3eddbe133d3 100644
--- a/torchmetrics/functional/classification/__init__.py
+++ b/torchmetrics/functional/classification/__init__.py
@@ -17,7 +17,12 @@
 from torchmetrics.functional.classification.average_precision import average_precision  # noqa: F401
 from torchmetrics.functional.classification.calibration_error import calibration_error  # noqa: F401
 from torchmetrics.functional.classification.cohen_kappa import cohen_kappa  # noqa: F401
-from torchmetrics.functional.classification.confusion_matrix import confusion_matrix  # noqa: F401
+from torchmetrics.functional.classification.confusion_matrix import (  # noqa: F401
+    confusion_matrix,
+    binary_confusion_matrix,
+    multiclass_confusion_matrix,
+    multilabel_confusion_matrix,
+)
 from torchmetrics.functional.classification.dice import dice, dice_score  # noqa: F401
 from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score  # noqa: F401
 from torchmetrics.functional.classification.hamming import hamming_distance  # noqa: F401
@@ -34,4 +39,9 @@
 )
 from torchmetrics.functional.classification.roc import roc  # noqa: F401
 from torchmetrics.functional.classification.specificity import specificity  # noqa: F401
-from torchmetrics.functional.classification.stat_scores import stat_scores  # noqa: F401
+from torchmetrics.functional.classification.stat_scores import (  # noqa: F401
+    stat_scores,
+    binary_stat_scores,
+    multiclass_stat_scores,
+    multilabel_stat_scores,
+)
diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py
index 7e3b6217343..c9cbbeabde6 100644
--- a/torchmetrics/functional/classification/confusion_matrix.py
+++ b/torchmetrics/functional/classification/confusion_matrix.py
@@ -16,8 +16,8 @@
 import torch
 from torch import Tensor
 
-from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.checks import _input_format_classification
+from torchmetrics.utilities.prints import rank_zero_warn
+from torchmetrics.utilities.checks import _input_format_classification, _check_same_shape
 from torchmetrics.utilities.data import _bincount
 from torchmetrics.utilities.enums import DataType
@@ -464,5 +464,12 @@ def confusion_matrix(
         [[0, 1], [0, 1]]])
     """
+    rank_zero_warn(
+        "`torchmetrics.functional.confusion_matrix` has been deprecated in v0.10 in favor of"
+        " `torchmetrics.functional.binary_confusion_matrix`, `torchmetrics.functional.multiclass_confusion_matrix`"
+        " and `torchmetrics.functional.multilabel_confusion_matrix`. Please upgrade to the version that matches"
+        " your problem (the API may have changed).
diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index f48bc364f97..0b58c383aaa 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -16,8 +16,9 @@ import torch from torch import Tensor, tensor -from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.checks import _input_format_classification, _check_same_shape from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn def _binary_stat_scores_arg_validation( @@ -178,7 +179,7 @@ def _multiclass_stat_scores_format(): def _multiclass_stat_scores_update(): - + pass def multiclass_stat_scores( @@ -200,8 +201,6 @@ - - def _multilabel_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" ) -> Tensor: @@ -209,6 +208,9 @@ return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) +def multilabel_stat_scores(): + pass + # -------------------------- Old stuff -------------------------- @@ -608,6 +610,13 @@ def stat_scores( tensor([2, 2, 6, 2, 4]) """ + rank_zero_warn( + "`torchmetrics.functional.stat_scores` has been deprecated in v0.10 in favor of " + "`torchmetrics.functional.binary_stat_scores`, `torchmetrics.functional.multiclass_stat_scores` " + "and `torchmetrics.functional.multilabel_stat_scores`. Please upgrade to the version that matches " + "your problem (API may have changed). This function will be removed in v0.11.", + DeprecationWarning + ) if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.")
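The packing convention used by the `*_stat_scores_compute` helpers above deserves a spelled-out example: all five statistics come back in a single tensor ordered `[tp, fp, tn, fn, support]`, where, as the helpers are written at this point in the series, the last entry is the total count `tp + fp + tn + fn` (concatenated for `multidim_average="global"`, one stacked row per sample otherwise). A torch-only sketch:

```python
import torch

preds = torch.tensor([1, 1, 0, 0])
target = torch.tensor([1, 0, 0, 1])

tp = ((preds == 1) & (target == 1)).sum()  # 1
fp = ((preds == 1) & (target == 0)).sum()  # 1
tn = ((preds == 0) & (target == 0)).sum()  # 1
fn = ((preds == 0) & (target == 1)).sum()  # 1

# same layout as _binary_stat_scores_compute for the "global" case
stats = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn])
print(stats)  # tensor([1, 1, 1, 1, 4])
```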
From 6d5ee7069f3200b198db7028ca0780e5515dc37d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 2 Jun 2022 17:09:45 +0200 Subject: [PATCH 12/74] rest of structure --- torchmetrics/classification/__init__.py | 13 +- .../classification/confusion_matrix.py | 24 ++- torchmetrics/classification/stat_scores.py | 153 ++++++++++++++++-- torchmetrics/functional/__init__.py | 21 ++- .../functional/classification/stat_scores.py | 33 ++-- 5 files changed, 219 insertions(+), 25 deletions(-) diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 70ae4d5179c..f826af79d94 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -20,7 +20,11 @@ from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 -from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 +from torchmetrics.classification.confusion_matrix import ( # noqa: F401 + ConfusionMatrix, + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 @@ -37,4 +41,9 @@ ) from torchmetrics.classification.roc import ROC # noqa: F401 from torchmetrics.classification.specificity import Specificity # noqa: F401 -from torchmetrics.classification.stat_scores import StatScores # noqa: F401 +from torchmetrics.classification.stat_scores import ( # noqa: F401 + StatScores, + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, +) diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 19b43ff4a0d..3779b9525f7 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -16,7 +16,25 @@ import torch from torch import Tensor -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update +from torchmetrics.functional.classification.confusion_matrix import ( + _confusion_matrix_compute, + _confusion_matrix_update, + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update +) from torchmetrics.metric import Metric @@ -54,7 +72,7 @@ def compute(self) -> Tensor: return _binary_confusion_matrix_compute(self.confmat, self.normalize) -class MultiClassConfusionMatrix(Metric): +class MulticlassConfusionMatrix(Metric): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -88,7 +106,7 @@ def compute(self) -> Tensor: return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) -class MultiLabelConfusionMatrix(Metric): +class MultilabelConfusionMatrix(Metric): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False
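The new modules follow the standard `Metric` contract: state accumulates across `update` calls until `compute` or `reset`. A small usage sketch, assuming the constructor arguments stay as the diff suggests:

```python
import torch
from torchmetrics.classification import MulticlassConfusionMatrix

metric = MulticlassConfusionMatrix(num_classes=3)
for _ in range(2):  # the confmat state is summed over all batches seen so far
    preds = torch.randint(3, (8,))
    target = torch.randint(3, (8,))
    metric.update(preds, target)
print(metric.compute())  # 3x3 matrix: rows = true class, cols = predicted class
metric.reset()
```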
diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 0b1290844db..4a37854d22c 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -16,7 +16,25 @@ import torch from torch import Tensor -from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update +from torchmetrics.functional.classification.stat_scores import ( + _stat_scores_compute, + _stat_scores_update, + _binary_stat_scores_arg_validation, + _binary_stat_scores_compute, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_compute, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_compute, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update +) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod @@ -43,15 +61,15 @@ def __init__( self.validate_args = validate_args if self.multidim_average == "samplewise": - self.add_state("tp", [], dist_reduce_fx="cat") - self.add_state("fp", [], dist_reduce_fx="cat") - self.add_state("tn", [], dist_reduce_fx="cat") - self.add_state("fn", [], dist_reduce_fx="cat") + default = lambda : [ ] + dist_reduce_fx = "cat" else: - self.add_state("tp", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("fp", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("tn", torch.zeros(1), dist_reduce_fx="sum") - self.add_state("fn", torch.zeros(1), dist_reduce_fx="sum") + default = lambda : torch.zeros(1) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: @@ -73,6 +91,123 @@ def compute(self) -> Tensor: return _binary_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average)
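The `default`/`dist_reduce_fx` factoring above encodes a real constraint: sample-wise statistics grow with the data, so they must live in list states that distributed training synchronizes by concatenation, while global counts fit in a fixed-size tensor reduced by summation. A hypothetical toy metric (not part of the patch) showing the same pattern:

```python
import torch
from torchmetrics import Metric

class TruePositiveCount(Metric):
    """Toy metric mirroring the list-vs-tensor state setup used above."""

    def __init__(self, samplewise: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.samplewise = samplewise
        if samplewise:
            # one entry appended per update(); processes sync by concatenation
            self.add_state("tp", default=[], dist_reduce_fx="cat")
        else:
            # a single running counter; processes sync by summation
            self.add_state("tp", default=torch.zeros(1), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        tp = ((preds == 1) & (target == 1)).sum().reshape(1)
        if self.samplewise:
            self.tp.append(tp)
        else:
            self.tp += tp

    def compute(self) -> torch.Tensor:
        # after cross-process sync, a list state may already be one tensor
        return torch.cat(self.tp) if isinstance(self.tp, list) else self.tp
```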
self.average != "micro": + default = lambda : torch.zeros(num_classes) + dist_reduce_fx = "sum" + else: + default = lambda : torch.zeros(1) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multiclass_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, self.multidim_average) + if self.multidim_average == "samplewise" or self.average == "samples": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def compute(self) -> Tensor: + return _multiclass_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + + +class MultilabelStatScores(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: str = "micro", + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Dict[str, Any], + ) -> None: + super.__init__(**kwargs) + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + self.num_labels = num_labels + self.threshold = threshold + self.average = average + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + if self.multidim_average == "samplewise" or self.average == "samples": + default = lambda : [ ] + dist_reduce_fx = "cat" + elif self.average != "micro": + default = lambda : torch.zeros(num_labels) + dist_reduce_fx = "sum" + else: + default = lambda : torch.zeros(1) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + if self.validate_args: + _multilabel_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, self.multidim_average) + if self.multidim_average == "samplewise" or self.average == "samples": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def compute(self) -> Tensor: + return _multilabel_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + # -------------------------- Old stuff -------------------------- diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 08f0ac2e46c..67e67e373e0 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -20,7 +20,12 @@ from 
torchmetrics.functional.classification.average_precision import average_precision from torchmetrics.functional.classification.calibration_error import calibration_error from torchmetrics.functional.classification.cohen_kappa import cohen_kappa -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix +from torchmetrics.functional.classification.confusion_matrix import ( + confusion_matrix, + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) from torchmetrics.functional.classification.dice import dice, dice_score from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score from torchmetrics.functional.classification.hamming import hamming_distance @@ -37,7 +42,12 @@ ) from torchmetrics.functional.classification.roc import roc from torchmetrics.functional.classification.specificity import specificity -from torchmetrics.functional.classification.stat_scores import stat_scores +from torchmetrics.functional.classification.stat_scores import ( + stat_scores, + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, +) from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.functional.image.gradients import image_gradients @@ -168,4 +178,11 @@ "word_error_rate", "word_information_lost", "word_information_preserved", +] + [ + "binary_confusion_matrix", + "multiclass_confusion_matrix", + "multilabel_confusion_matrix", + "binary_stat_scores", + "multiclass_stat_scores", + "multilabel_stat_scores", ] diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 0b58c383aaa..fd8549fb6a4 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -130,14 +130,6 @@ def binary_stat_scores( return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) -def _multiclass_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" -) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) - - def _multiclass_stat_scores_arg_validation( num_classes: int, top_k: int = 1, @@ -182,6 +174,14 @@ def _multiclass_stat_scores_update(): pass +def _multiclass_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" +) -> Tensor: + if multidim_average == "global": + return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + + def multiclass_stat_scores( preds: Tensor, target: Tensor, @@ -199,7 +199,17 @@ def multiclass_stat_scores( tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) +def _multilabel_stat_scores_arg_validation(): + pass + +def _multilabel_stat_scores_tensor_validation(): + pass + +def _multilabel_stat_scores_format(): + pass +def _multilabel_stat_scores_update(): + pass def _multilabel_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" @@ -209,7 +219,12 @@ def _multilabel_stat_scores_compute( return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) def multilabel_stat_scores(): - pass + if validate_args: 
+ _multilabel_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _multilabel_stat_scores_compute(tp, fp, tn, fn, multidim_average) # -------------------------- Old stuff -------------------------- From dee14cf81a988c5bc433adb9e40876afd629b9a9 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 3 Jun 2022 20:27:30 +0200 Subject: [PATCH 13/74] confmat working --- tests/classification/inputs.py | 90 ++++ tests/classification/test_confusion_matrix.py | 441 ++++++++++++------ tests/helpers/testers.py | 2 +- torchmetrics/classification/__init__.py | 7 +- .../classification/confusion_matrix.py | 14 +- torchmetrics/classification/stat_scores.py | 23 +- .../classification/confusion_matrix.py | 50 +- .../functional/classification/stat_scores.py | 39 +- torchmetrics/utilities/checks.py | 2 +- 9 files changed, 477 insertions(+), 191 deletions(-) diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py index 635f99957f1..aa328fd14c4 100644 --- a/tests/classification/inputs.py +++ b/tests/classification/inputs.py @@ -14,12 +14,22 @@ from collections import namedtuple import torch +from torch import Tensor from tests.helpers import seed_all from tests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES seed_all(1) + +def _inv_sigmoid(x: Tensor) -> Tensor: + return (x / (1 - x)).log() + + +def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: + return torch.nn.functional.log_softmax(x, dim) + + Input = namedtuple("Input", ["preds", "target"]) _input_binary_prob = Input( @@ -60,6 +70,86 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ) +_binary_cases = ( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), +) + + +_multiclass_cases = ( + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + 
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), +) + + +_multilabel_cases = ( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index f45c7c30c29..194be2eb457 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -11,219 +11,366 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
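One non-obvious piece of the fixtures above: the logit cases are built by inverse-sigmoiding uniform probabilities, so any metric that auto-applies sigmoid to out-of-range inputs recovers exactly the probability fixtures. A quick check of that round trip:

```python
import torch

def _inv_sigmoid(x: torch.Tensor) -> torch.Tensor:  # as defined in inputs.py above
    return (x / (1 - x)).log()

p = torch.rand(10)        # probabilities in [0, 1)
logits = _inv_sigmoid(p)  # unbounded real values
assert torch.allclose(torch.sigmoid(logits), p, atol=1e-6)
```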
+from copy import deepcopy from functools import partial import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix - -from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from tests.classification.inputs import _input_multiclass as _input_mcls -from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.classification.inputs import _input_multilabel as _input_mlb -from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torch import Tensor + +from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix +from torchmetrics.classification.confusion_matrix import ( + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) +from torchmetrics.functional.classification.confusion_matrix import ( + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) seed_all(42) -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_int.preds, _input_binary_int.target), - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary_logit.preds, _input_binary_logit.target), - (_input_binary_int_multidim.preds, _input_binary_int_multidim.target), - (_input_binary_prob_multidim.preds, _input_binary_prob_multidim.target), - (_input_binary_logit_multidim.preds, _input_binary_logit_multidim.target), - ] -) +def _inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: + idx = torch.randperm(x.numel()) + x = deepcopy(x) + x.view(-1)[idx[::5]] = -1 + return x + + +def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize) + + +@pytest.mark.parametrize("input", _binary_cases) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +@pytest.mark.parametrize("ignore_index", [None, -1]) class TestBinaryConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - def test_binary_confusion_matrix(self, preds, target, ddp, normalize): + def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=BinaryConfusionMatrix, - 
sk_metric=_sk_confusion_matrix_binary, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, - "normalize": normalize - } + "normalize": normalize, + "ignore_index": ignore_index, + }, ) - def test_confusion_matrix_functional(self, preds, target, ddp, normalize): + def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, metric_functional=binary_confusion_matrix, - sk_metric=_sk_confusion_matrix_binary, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), metric_args={ "threshold": THRESHOLD, - "normalize": normalize - } + "normalize": normalize, + "ignore_index": ignore_index, + }, ) -# -------------------------- Old stuff -------------------------- - -def _sk_cm_binary_prob(preds, target, normalize=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_binary(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multilabel_prob(preds, target, normalize=None): - sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.numpy() - - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm - - -def _sk_cm_multilabel(preds, target, normalize=None): - sk_preds = preds.numpy() - sk_target = target.numpy() - - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm - - -def _sk_cm_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - -def _sk_cm_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +def _sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multidim_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return 
sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize) +@pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), - (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), - (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), - (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False), - ], -) -class TestConfusionMatrix(MetricTester): +@pytest.mark.parametrize("ignore_index", [None, -1]) +class TestMulticlassConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_confusion_matrix( - self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step - ): + def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=ConfusionMatrix, - sk_metric=partial(sk_metric, normalize=normalize), - dist_sync_on_step=dist_sync_on_step, + metric_class=MulticlassConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_classes": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): + def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=confusion_matrix, - sk_metric=partial(sk_metric, normalize=normalize), + metric_functional=multiclass_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + + +def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 
< preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + confmat = [] + for i in range(preds.shape[1]): + p, t = preds[:, i], target[:, i] + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + confmat.append(sk_confusion_matrix(t, p, normalize=normalize)) + return np.stack(confmat, axis=0) + + +@pytest.mark.parametrize("input", _multilabel_cases) +@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +@pytest.mark.parametrize("ignore_index", [None, -1]) +class TestMultilabelConfusionMatrix(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): - self.run_differentiability_test( + def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = _inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( preds=preds, target=target, - metric_module=ConfusionMatrix, - metric_functional=confusion_matrix, + metric_functional=multilabel_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) -def test_warning_on_nan(tmpdir): +def test_warning_on_nan(): preds = torch.randint(3, size=(20,)) target = torch.randint(3, size=(20,)) with pytest.warns( UserWarning, - match=".* nan values found in confusion matrix have been replaced with zeros.", + match=".* NaN values found in confusion matrix have been replaced with zeros.", ): - confusion_matrix(preds, target, num_classes=5, normalize="true") + multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") + + +# -------------------------- Old stuff -------------------------- + +# def _sk_cm_binary_prob(preds, target, normalize=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_binary(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multilabel_prob(preds, target, normalize=None): +# sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.numpy() + +# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) +# if normalize is not None: +# if normalize == "true": +# cm = cm / cm.sum(axis=1, keepdims=True) +# elif normalize == "pred": +# cm = cm 
/ cm.sum(axis=0, keepdims=True) +# elif normalize == "all": +# cm = cm / cm.sum() +# cm[np.isnan(cm)] = 0 +# return cm + + +# def _sk_cm_multilabel(preds, target, normalize=None): +# sk_preds = preds.numpy() +# sk_target = target.numpy() + +# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) +# if normalize is not None: +# if normalize == "true": +# cm = cm / cm.sum(axis=1, keepdims=True) +# elif normalize == "pred": +# cm = cm / cm.sum(axis=0, keepdims=True) +# elif normalize == "all": +# cm = cm / cm.sum() +# cm[np.isnan(cm)] = 0 +# return cm + + +# def _sk_cm_multiclass_prob(preds, target, normalize=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multiclass(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multidim_multiclass(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes, multilabel", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), +# (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), +# (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), +# (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), +# (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), +# (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), +# (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), +# (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False), +# ], +# ) +# class TestConfusionMatrix(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_confusion_matrix( +# self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step +# ): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=ConfusionMatrix, +# sk_metric=partial(sk_metric, normalize=normalize), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + +# def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): +# self.run_functional_metric_test( +# 
preds=preds, +# target=target, +# metric_functional=confusion_matrix, +# sk_metric=partial(sk_metric, normalize=normalize), +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + +# def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=ConfusionMatrix, +# metric_functional=confusion_matrix, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + + +# def test_warning_on_nan(tmpdir): +# preds = torch.randint(3, size=(20,)) +# target = torch.randint(3, size=(20,)) + +# with pytest.warns( +# UserWarning, +# match=".* nan values found in confusion matrix have been replaced with zeros.", +# ): +# confusion_matrix(preds, target, num_classes=5, normalize="true") diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index d37e8d4d78a..4ee1e023c76 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -402,7 +402,7 @@ def run_class_metric_test( target: Union[Tensor, List[Dict]], metric_class: Metric, sk_metric: Callable, - dist_sync_on_step: bool, + dist_sync_on_step: bool = False, metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index f826af79d94..2575b9f9ba2 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -21,10 +21,11 @@ from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 from torchmetrics.classification.confusion_matrix import ( # noqa: F401 - ConfusionMatrix, BinaryConfusionMatrix, + ConfusionMatrix, MulticlassConfusionMatrix, - MultilabelConfusionMatrix + MultilabelConfusionMatrix, +) from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 @@ -42,8 +43,8 @@ from torchmetrics.classification.roc import ROC # noqa: F401 from torchmetrics.classification.specificity import Specificity # noqa: F401 from torchmetrics.classification.stat_scores import ( # noqa: F401 - StatScores, BinaryStatScores, MulticlassStatScores, MultilabelStatScores, + StatScores, ) diff --git a/torchmetrics/classification/confusion_matrix.py b/torchmetrics/classification/confusion_matrix.py index 3779b9525f7..05dbfd19edf 100644 --- a/torchmetrics/classification/confusion_matrix.py +++ b/torchmetrics/classification/confusion_matrix.py @@ -17,13 +17,13 @@ from torch import Tensor from torchmetrics.functional.classification.confusion_matrix import ( - _confusion_matrix_compute, - _confusion_matrix_update, _binary_confusion_matrix_arg_validation, _binary_confusion_matrix_compute, _binary_confusion_matrix_format, _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, _multiclass_confusion_matrix_compute, _multiclass_confusion_matrix_format, @@ -33,7 +33,7 @@ _multilabel_confusion_matrix_compute, _multilabel_confusion_matrix_format, _multilabel_confusion_matrix_tensor_validation, - 
_multilabel_confusion_matrix_update + _multilabel_confusion_matrix_update, ) from torchmetrics.metric import Metric @@ -114,6 +114,7 @@ class MultilabelConfusionMatrix(Metric): def __init__( self, num_labels: int, + threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None, validate_args: bool = True, @@ -121,8 +122,9 @@ def __init__( ) -> None: super().__init__(**kwargs) if validate_args: - _multilabel_confusion_matrix_arg_validation(num_labels, ignore_index, normalize) + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) self.num_labels = num_labels + self.threshold = threshold self.ignore_index = ignore_index self.normalize = normalize self.validate_args = validate_args @@ -132,7 +134,9 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) - preds, target = _multilabel_confusion_matrix_format(preds, target, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) self.confmat += confmat diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 4a37854d22c..74e6d013b5e 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -17,8 +17,6 @@ from torch import Tensor from torchmetrics.functional.classification.stat_scores import ( - _stat_scores_compute, - _stat_scores_update, _binary_stat_scores_arg_validation, _binary_stat_scores_compute, _binary_stat_scores_format, @@ -33,7 +31,9 @@ _multilabel_stat_scores_compute, _multilabel_stat_scores_format, _multilabel_stat_scores_tensor_validation, - _multilabel_stat_scores_update + _multilabel_stat_scores_update, + _stat_scores_compute, + _stat_scores_update, ) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod @@ -61,10 +61,10 @@ def __init__( self.validate_args = validate_args if self.multidim_average == "samplewise": - default = lambda : [ ] + default = lambda: [] dist_reduce_fx = "cat" else: - default = lambda : torch.zeros(1) + default = lambda: torch.zeros(1) dist_reduce_fx = "sum" self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) @@ -117,13 +117,13 @@ def __init__( self.validate_args = validate_args if self.multidim_average == "samplewise" or self.average == "samples": - default = lambda : [ ] + default = lambda: [] dist_reduce_fx = "cat" elif self.average != "micro": - default = lambda : torch.zeros(num_classes) + default = lambda: torch.zeros(num_classes) dist_reduce_fx = "sum" else: - default = lambda : torch.zeros(1) + default = lambda: torch.zeros(1) dist_reduce_fx = "sum" self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) @@ -176,13 +176,13 @@ def __init__( self.validate_args = validate_args if self.multidim_average == "samplewise" or self.average == "samples": - default = lambda : [ ] + default = lambda: [] dist_reduce_fx = "cat" elif self.average != "micro": - default = lambda : torch.zeros(num_labels) + default = lambda: torch.zeros(num_labels) dist_reduce_fx = "sum" else: - default = lambda : torch.zeros(1) + default = lambda: 
torch.zeros(1) dist_reduce_fx = "sum" self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) @@ -208,6 +208,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore def compute(self) -> Tensor: return _multilabel_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + # -------------------------- Old stuff -------------------------- diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py index c9cbbeabde6..7b5b271ca7a 100644 --- a/torchmetrics/functional/classification/confusion_matrix.py +++ b/torchmetrics/functional/classification/confusion_matrix.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor -from torchmetrics.utilities.prints import rank_zero_warn -from torchmetrics.utilities.checks import _input_format_classification, _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.prints import rank_zero_warn def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, multilabel: bool = False) -> Tensor: @@ -33,7 +33,7 @@ def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, m elif normalize == "pred": confmat = confmat / confmat.sum(axis=1 if multilabel else 0, keepdim=True) elif normalize == "all": - confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1]) + confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1], keepdim=True) nan_elements = confmat[torch.isnan(confmat)].nelement() if nan_elements != 0: @@ -70,8 +70,8 @@ def _binary_confusion_matrix_tensor_validation( check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) if check: raise RuntimeError( - "Detected the following values in `target`: {unique_values} but expected only" - " the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." ) # If preds is label tensor, also check that it only contains [0,1] values @@ -79,7 +79,7 @@ def _binary_confusion_matrix_tensor_validation( unique_values = torch.unique(preds) if torch.any((unique_values != 0) & (unique_values != 1)): raise RuntimeError( - "Detected the following values in `preds`: {unique_values} but expected only" + f"Detected the following values in `preds`: {unique_values} but expected only" " the following values [0,1] since preds is a label tensor." 
) @@ -236,10 +236,12 @@ def _multilabel_confusion_matrix_arg_validation( - num_labels: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None + num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None ) -> None: if not isinstance(num_labels, int) or num_labels < 2: raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") + if not isinstance(threshold, float): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") allowed_normalize = ("true", "pred", "all", "none", None) @@ -250,7 +252,30 @@ def _multilabel_confusion_matrix_tensor_validation( preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None ) -> None: - pass + """Validate tensor input.""" + # Check that they have same shape + _check_same_shape(preds, target) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." + ) def _multilabel_confusion_matrix_format( @@ -260,11 +285,12 @@ if not ((0 <= preds) * (preds <= 1)).all(): preds = preds.sigmoid() preds = preds > threshold - preds = preds.movedim(1, -1).reshape(-1, num_labels) target = target.movedim(1, -1).reshape(-1, num_labels) + + if ignore_index is not None: + preds = preds.clone() + target = target.clone() # make sure that when we map, it will always result in a negative number that we can filter away idx = target == ignore_index preds[idx] = -4 * num_labels @@ -275,7 +301,7 @@ def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() - unique_mapping = unique_mapping[unique_mapping > 0] + unique_mapping = unique_mapping[unique_mapping >= 0] bins = _bincount(unique_mapping, minlength=4 * num_labels) return bins.reshape(num_labels, 2, 2)
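The encoding behind `_multilabel_confusion_matrix_update` is compact enough to miss: every (label, target, pred) cell maps to the integer `2 * target + pred + 4 * label`, so a single bincount of length `4 * num_labels` reshapes directly into per-label 2x2 matrices. The `>= 0` change above matters because code `0` (a true negative of label 0) is a valid bin, while `ignore_index` positions are pushed to strictly negative codes; filtering with `> 0` silently dropped label 0's true negatives. A torch-only sketch with two labels and no ignored entries:

```python
import torch

num_labels = 2
preds = torch.tensor([[1, 0], [1, 1], [0, 0]])
target = torch.tensor([[1, 0], [0, 1], [0, 1]])

unique_mapping = (2 * target + preds + 4 * torch.arange(num_labels)).flatten()
bins = torch.bincount(unique_mapping, minlength=4 * num_labels)
print(bins.reshape(num_labels, 2, 2))
# label 0: [[tn, fp], [fn, tp]] = [[1, 1], [0, 1]]
# label 1: [[tn, fp], [fn, tp]] = [[1, 0], [1, 1]]
```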
This function will be removed in v0.11.", - DeprecationWarning + DeprecationWarning, ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel) return _confusion_matrix_compute(confmat, normalize) diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index fd8549fb6a4..37fe0d9f046 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -16,7 +16,7 @@ import torch from torch import Tensor, tensor -from torchmetrics.utilities.checks import _input_format_classification, _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod from torchmetrics.utilities.prints import rank_zero_warn @@ -114,6 +114,7 @@ def _binary_stat_scores_compute( return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + def binary_stat_scores( preds: Tensor, target: Tensor, @@ -133,7 +134,7 @@ def _multiclass_stat_scores_arg_validation( num_classes: int, top_k: int = 1, - average: str = 'micro', + average: str = "micro", multidim_average: str = "global", ignore_index: Optional[int] = None, ) -> None: @@ -143,9 +144,7 @@ raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: - raise ValueError( - f"Expected argument `average` to be one of {allowed_average}, but got {average}" - ) + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") allowed_multidim_average = ("global", "samplewise") if multidim_average not in allowed_multidim_average: raise ValueError( @@ -160,7 +159,7 @@ target: Tensor, num_classes: int, top_k: int = 1, - multidim_average: str = 'global', + multidim_average: str = "global", ignore_index: Optional[int] = None, ) -> None: pass @@ -194,23 +193,30 @@ ) -> Tensor: if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation( + preds, target, num_classes, top_k, average, multidim_average, ignore_index + ) preds, target = _multiclass_stat_scores_format(preds, target, ignore_index) tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) + def _multilabel_stat_scores_arg_validation(): pass + def _multilabel_stat_scores_tensor_validation(): pass + def _multilabel_stat_scores_format(): pass + def _multilabel_stat_scores_update(): pass + def _multilabel_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" ) -> Tensor: @@ -218,14 +224,25 @@ return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) -def multilabel_stat_scores(): + +def multilabel_stat_scores( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 
0.5, + average: str = "micro", + multidim_average: str = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: if validate_args: - _multilabel_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_classes, top_k, average, multidim_average, ignore_index) + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, average, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _multilabel_stat_scores_compute(tp, fp, tn, fn, multidim_average) + # -------------------------- Old stuff -------------------------- @@ -630,7 +647,7 @@ "`torchmetrics.functional.binary_stat_scores`, `torchmetrics.functional.multiclass_stat_scores` " "and `torchmetrics.functional.multilabel_stat_scores`. Please upgrade to the version that matches " "your problem (API may have changed). This function will be removed in v0.11.", - DeprecationWarning + DeprecationWarning, ) if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") diff --git a/torchmetrics/utilities/checks.py b/torchmetrics/utilities/checks.py index bf853e1226c..6575343ec0c 100644 --- a/torchmetrics/utilities/checks.py +++ b/torchmetrics/utilities/checks.py @@ -31,7 +31,7 @@ def _check_same_shape(preds: Tensor, target: Tensor) -> None: """Check that predictions and target have the same shape, else raise error.""" if preds.shape != target.shape: raise RuntimeError( - "Predictions and targets are expected to have the same shape," " but got {preds.shape} and {target.shape}." + f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}."
) From b9bd9dc7471d01f57fd4d438ff28377aea9ba6ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Jun 2022 18:40:36 +0000 Subject: [PATCH 14/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torchmetrics/functional/__init__.py | 4 ++-- torchmetrics/functional/classification/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 67e67e373e0..aee8a90b1aa 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -21,8 +21,8 @@ from torchmetrics.functional.classification.calibration_error import calibration_error from torchmetrics.functional.classification.cohen_kappa import cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( - confusion_matrix, binary_confusion_matrix, + confusion_matrix, multiclass_confusion_matrix, multilabel_confusion_matrix, ) @@ -43,10 +43,10 @@ from torchmetrics.functional.classification.roc import roc from torchmetrics.functional.classification.specificity import specificity from torchmetrics.functional.classification.stat_scores import ( - stat_scores, binary_stat_scores, multiclass_stat_scores, multilabel_stat_scores, + stat_scores, ) from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index 3eddbe133d3..f1efe75fb5b 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -18,8 +18,8 @@ from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 - confusion_matrix, binary_confusion_matrix, + confusion_matrix, multiclass_confusion_matrix, multilabel_confusion_matrix, ) @@ -40,8 +40,8 @@ from torchmetrics.functional.classification.roc import roc # noqa: F401 from torchmetrics.functional.classification.specificity import specificity # noqa: F401 from torchmetrics.functional.classification.stat_scores import ( # noqa: F401 - stat_scores, binary_stat_scores, multiclass_stat_scores, multilabel_stat_scores, + stat_scores, ) From f1c664d670e8d5769e59c8fe77ed79756df95698 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 5 Jun 2022 13:27:04 +0200 Subject: [PATCH 15/74] working binary stat scores --- tests/classification/test_confusion_matrix.py | 23 +- tests/classification/test_stat_scores.py | 664 ++++++++++-------- tests/helpers/testers.py | 8 + torchmetrics/classification/stat_scores.py | 9 +- .../functional/classification/stat_scores.py | 7 +- 5 files changed, 400 insertions(+), 311 deletions(-) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 194be2eb457..480a8954a97 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
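# A minimal, self-contained sketch of the boolean-mask counting that the
# `binary_stat_scores` path in these patches is built on; the function name here
# is illustrative, not part of the patch. Hard {0, 1} inputs are assumed, and the
# output follows the [tp, fp, tn, fn, support] layout used throughout.
import torch

def _binary_counts_sketch(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    tp = ((target == preds) & (target == 1)).sum()
    fn = ((target != preds) & (target == 1)).sum()
    fp = ((target != preds) & (target == 0)).sum()
    tn = ((target == preds) & (target == 0)).sum()
    return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn])

preds = torch.tensor([1, 0, 1, 1, 0])
target = torch.tensor([1, 0, 0, 1, 1])
print(_binary_counts_sketch(preds, target))  # tensor([2, 1, 1, 1, 5])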
-from copy import deepcopy from functools import partial import numpy as np @@ -19,11 +18,10 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from torch import Tensor from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, @@ -38,13 +36,6 @@ seed_all(42) -def _inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: - idx = torch.randperm(x.numel()) - x = deepcopy(x) - x.view(-1)[idx[::5]] = -1 - return x - - def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -67,7 +58,7 @@ class TestBinaryConfusionMatrix(MetricTester): def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, @@ -84,7 +75,7 @@ def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, @@ -121,7 +112,7 @@ class TestMulticlassConfusionMatrix(MetricTester): def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, @@ -138,7 +129,7 @@ def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, @@ -180,7 +171,7 @@ class TestMultilabelConfusionMatrix(MetricTester): def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, @@ -197,7 +188,7 @@ def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: - target = _inject_ignore_index(target, ignore_index) + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 6f19e1c148b..67558add06e 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -12,317 +12,403 @@ # See the License for the specific 
language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional import numpy as np import pytest -import torch -from sklearn.metrics import multilabel_confusion_matrix -from torch import Tensor, tensor - -from tests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass -from tests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from tests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from tests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from tests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from tests.classification.inputs import _input_multilabel as _input_mcls -from tests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix + +from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics import StatScores -from torchmetrics.functional import stat_scores -from torchmetrics.utilities.checks import _input_format_classification +from tests.helpers.testers import THRESHOLD, MetricTester, inject_ignore_index +from torchmetrics.classification.stat_scores import BinaryStatScores +from torchmetrics.functional.classification.stat_scores import binary_stat_scores seed_all(42) -def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None): - # todo: `mdmc_reduce` is unused - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() - - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 1) - - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T - - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) - - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] - - if reduce == "micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) - - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) - - if reduce == "micro": - sk_stats = sk_stats[0] - - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 - - return sk_stats - - -def _sk_stat_scores_mdim_mcls( - preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold -): - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = 
torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - - return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold) - if mdmc_reduce == "samplewise": - scores = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold) - - scores.append(np.expand_dims(scores_i, 0)) - - return np.concatenate(scores) - - -@pytest.mark.parametrize( - "reduce, mdmc_reduce, num_classes, inputs, ignore_index", - [ - ["unknown", None, None, _input_binary, None], - ["micro", "unknown", None, _input_binary, None], - ["macro", None, None, _input_binary, None], - ["micro", None, None, _input_mdmc_prob, None], - ["micro", None, None, _input_binary_prob, 0], - ["micro", None, None, _input_mcls_prob, NUM_CLASSES], - ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES], - ], -) -def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): - """Test a combination of parameters that are invalid and should raise an error. - - This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when - ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index`` - when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes. - """ - with pytest.raises(ValueError): - stat_scores( - inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index - ) - - with pytest.raises(ValueError): - sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) - sts(inputs.preds[0], inputs.target[0]) - - -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) -@pytest.mark.parametrize( - "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold", - [ - (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0), - (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5), - (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5), - (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0), - (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None, 0.0), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "samplewise", - NUM_CLASSES, - None, - None, - 0.0, - ), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0), - ( - 
_input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "global", - NUM_CLASSES, - None, - None, - 0.0, - ), - ], -) -class TestStatScores(MetricTester): - # DDP tests temporarily disabled due to hanging issues - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - @pytest.mark.parametrize("dtype", [torch.float, torch.double]) - def test_stat_scores_class( - self, - ddp: bool, - dist_sync_on_step: bool, - dtype: torch.dtype, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if preds.is_floating_point(): - preds = preds.to(dtype) - if target.is_floating_point(): - target = target.to(dtype) + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + if multidim_average == "global": + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return np.array([tp, fp, tn, fn, tp + fp + tn + fn]) + else: + res = [] + for pred, true in zip(preds, target): + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +@pytest.mark.parametrize("ignore_index", [None, 0, -1]) +@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) +class TestBinaryStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=StatScores, - sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - top_k=top_k, - threshold=threshold, - ), - dist_sync_on_step=dist_sync_on_step, - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, + metric_class=BinaryStatScores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - def test_stat_scores_fn( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping 
ignore_index test with binary inputs.") + def test_binary_stat_scores_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=stat_scores, - sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, - ignore_index=ignore_index, - top_k=top_k, - threshold=threshold, - ), + preds=preds, + target=target, + metric_functional=binary_stat_scores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "top_k": top_k, + "multidim_average": multidim_average, }, ) - def test_stat_scores_differentiability( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - self.run_differentiability_test( - preds, - target, - metric_module=StatScores, - metric_functional=stat_scores, - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, - ) +@pytest.mark.parametrize("input", _multiclass_cases) +@pytest.mark.parametrize("ignore_index", [None, 0, -1]) +@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) +class TestMulitclassStatScores(MetricTester): + pass + + +@pytest.mark.parametrize("input", _multilabel_cases) +@pytest.mark.parametrize("ignore_index", [None, 0, -1]) +@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) +class TestMultilabelyStatScores(MetricTester): + pass + + +# -------------------------- Old stuff -------------------------- -_mc_k_target = tensor([0, 1, 2]) -_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize( - "k, preds, target, reduce, expected", - [ - (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), - (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), - (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), - (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), - (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), - (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), - ], -) -def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): - """A simple test to check that top_k works as expected.""" - 
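# The rewritten tests above validate against scikit-learn directly instead of the
# old `_input_format_classification` pipeline. The core reference pattern, shown
# here as a sketch with made-up data: pinning `labels=[0, 1]` keeps the confusion
# matrix 2x2 even for degenerate batches, and `.ravel()` yields the counts in
# (tn, fp, fn, tp) order, which is then rearranged into the same
# [tp, fp, tn, fn, support] convention as the metric itself.
import numpy as np
from sklearn.metrics import confusion_matrix

target = np.array([1, 0, 0, 1, 1])
preds = np.array([1, 0, 1, 1, 0])
tn, fp, fn, tp = confusion_matrix(target, preds, labels=[0, 1]).ravel()
print(np.array([tp, fp, tn, fn, tp + fp + tn + fn]))  # [2 1 1 1 5]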
- class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) - class_metric.update(preds, target) - - assert torch.equal(class_metric.compute(), expected.T) - assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) +# def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None): +# # todo: `mdmc_reduce` is unused +# preds, target, _ = _input_format_classification( +# preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) +# sk_preds, sk_target = preds.numpy(), target.numpy() + +# if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: +# sk_preds = np.delete(sk_preds, ignore_index, 1) +# sk_target = np.delete(sk_target, ignore_index, 1) + +# if preds.shape[1] == 1 and reduce == "samples": +# sk_target = sk_target.T +# sk_preds = sk_preds.T + +# sk_stats = multilabel_confusion_matrix( +# sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 +# ) + +# if preds.shape[1] == 1 and reduce != "samples": +# sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +# else: +# sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + +# if reduce == "micro": +# sk_stats = sk_stats.sum(axis=0, keepdims=True) + +# sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + +# if reduce == "micro": +# sk_stats = sk_stats[0] + +# if reduce == "macro" and ignore_index is not None and preds.shape[1]: +# sk_stats[ignore_index, :] = -1 + +# return sk_stats + + +# def _sk_stat_scores_mdim_mcls( +# preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold +# ): +# preds, target, _ = _input_format_classification( +# preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) + +# if mdmc_reduce == "global": +# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) +# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + +# return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold) +# if mdmc_reduce == "samplewise": +# scores = [] + +# for i in range(preds.shape[0]): +# pred_i = preds[i, ...].T +# target_i = target[i, ...].T +# scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold) + +# scores.append(np.expand_dims(scores_i, 0)) + +# return np.concatenate(scores) + + +# @pytest.mark.parametrize( +# "reduce, mdmc_reduce, num_classes, inputs, ignore_index", +# [ +# ["unknown", None, None, _input_binary, None], +# ["micro", "unknown", None, _input_binary, None], +# ["macro", None, None, _input_binary, None], +# ["micro", None, None, _input_mdmc_prob, None], +# ["micro", None, None, _input_binary_prob, 0], +# ["micro", None, None, _input_mcls_prob, NUM_CLASSES], +# ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES], +# ], +# ) +# def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): +# """Test a combination of parameters that are invalid and should raise an error. + +# This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when +# ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index`` +# when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes. 
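# One index trick in the commented-out `_sk_stat_scores` above is easy to misread:
# sklearn's `multilabel_confusion_matrix` returns one 2x2 block [[tn, fp], [fn, tp]]
# per class, so `reshape(-1, 4)` gives rows (tn, fp, fn, tp) and the fancy index
# `[:, [3, 1, 0, 2]]` reorders each row to (tp, fp, tn, fn). A small sketch with
# made-up labels, kept only to document the legacy reference:
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

target = np.array([[0, 1, 0], [1, 1, 0], [0, 0, 1]])
preds = np.array([[0, 1, 1], [1, 0, 0], [0, 0, 1]])
stats = multilabel_confusion_matrix(target, preds)  # shape (num_classes, 2, 2)
print(stats.reshape(-1, 4)[:, [3, 1, 0, 2]])
# [[1 0 2 0]
#  [1 0 1 1]
#  [1 1 1 0]]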
+# """ +# with pytest.raises(ValueError): +# stat_scores( +# inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index +# ) + +# with pytest.raises(ValueError): +# sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) +# sts(inputs.preds[0], inputs.target[0]) + + +# @pytest.mark.parametrize("ignore_index", [None, 0]) +# @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) +# @pytest.mark.parametrize( +# "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold", +# [ +# (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0), +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5), +# (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5), +# (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5), +# (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), +# (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0), +# (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# ( +# _input_mdmc.preds, +# _input_mdmc.target, +# _sk_stat_scores_mdim_mcls, +# "samplewise", +# NUM_CLASSES, +# None, +# None, +# 0.0 +# ), +# ( +# _input_mdmc_prob.preds, +# _input_mdmc_prob.target, +# _sk_stat_scores_mdim_mcls, +# "samplewise", +# NUM_CLASSES, +# None, +# None, +# 0.0, +# ), +# (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0), +# ( +# _input_mdmc_prob.preds, +# _input_mdmc_prob.target, +# _sk_stat_scores_mdim_mcls, +# "global", +# NUM_CLASSES, +# None, +# None, +# 0.0, +# ), +# ], +# ) +# class TestStatScores(MetricTester): +# # DDP tests temporarily disabled due to hanging issues +# @pytest.mark.parametrize("ddp", [False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# @pytest.mark.parametrize("dtype", [torch.float, torch.double]) +# def test_stat_scores_class( +# self, +# ddp: bool, +# dist_sync_on_step: bool, +# dtype: torch.dtype, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# if preds.is_floating_point(): +# preds = preds.to(dtype) +# if target.is_floating_point(): +# target = target.to(dtype) + +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=StatScores, +# sk_metric=partial( +# sk_fn, +# reduce=reduce, +# mdmc_reduce=mdmc_reduce, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# top_k=top_k, +# threshold=threshold, +# ), +# 
dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + +# def test_stat_scores_fn( +# self, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=stat_scores, +# sk_metric=partial( +# sk_fn, +# reduce=reduce, +# mdmc_reduce=mdmc_reduce, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# top_k=top_k, +# threshold=threshold, +# ), +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + +# def test_stat_scores_differentiability( +# self, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# self.run_differentiability_test( +# preds, +# target, +# metric_module=StatScores, +# metric_functional=stat_scores, +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + + +# _mc_k_target = tensor([0, 1, 2]) +# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +# @pytest.mark.parametrize( +# "k, preds, target, reduce, expected", +# [ +# (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), +# (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), +# (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), +# (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), +# (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), +# (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), +# (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), +# (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), +# ], +# ) +# def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): +# """A simple test to check that top_k works as expected.""" + +# class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) +# class_metric.update(preds, target) + +# assert torch.equal(class_metric.compute(), expected.T) +# assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 4ee1e023c76..5da7cb3a3c9 100644 --- 
a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -14,6 +14,7 @@ import os import pickle import sys +from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Union @@ -619,3 +620,10 @@ def compute(self): class DummyMetricMultiOutput(DummyMetricSum): def compute(self): return [self.x, self.x] + + +def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: + idx = torch.randperm(x.numel()) + x = deepcopy(x) + x.view(-1)[idx[::5]] = ignore_index + return x diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 74e6d013b5e..56a54e952b5 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -64,7 +64,7 @@ def __init__( default = lambda: [] dist_reduce_fx = "cat" else: - default = lambda: torch.zeros(1) + default = lambda: torch.zeros(1, dtype=torch.long) dist_reduce_fx = "sum" self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) @@ -88,7 +88,11 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.fn += fn def compute(self) -> Tensor: - return _binary_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp + fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp + tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn + fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) class MulticlassStatScores(Metric): @@ -206,6 +210,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.fn += fn def compute(self) -> Tensor: + return _multilabel_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 37fe0d9f046..1634129d89f 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -88,6 +88,7 @@ def _binary_stat_scores_format( if ignore_index is not None: idx = target == ignore_index + target = target.clone() target[idx] = -1 return preds, target @@ -99,7 +100,7 @@ def _binary_stat_scores_update( multidim_average: str = "global", ) -> Tensor: """""" - sum_dim = 0 if multidim_average == "global" else 1 + sum_dim = [0, 1] if multidim_average == "global" else 1 tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() @@ -110,9 +111,7 @@ def _binary_stat_scores_update( def _binary_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" ) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=0 if multidim_average == "global" else 1).squeeze() def binary_stat_scores( From a21fc0f215931290d105f07e5f260b9280dc4b69 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 6 Jun 2022 14:44:35 +0200 Subject: [PATCH 16/74] full testing --- tests/classification/test_confusion_matrix.py | 12 +-- tests/classification/test_stat_scores.py | 54 
+++++++++- .../classification/confusion_matrix.py | 6 ++ .../functional/classification/stat_scores.py | 100 ++++++++++++++++-- 4 files changed, 153 insertions(+), 19 deletions(-) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 480a8954a97..906e8551afc 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -47,12 +47,12 @@ def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None idx = target == ignore_index target = target[~idx] preds = preds[~idx] - return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize) + return sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1], normalize=normalize) @pytest.mark.parametrize("input", _binary_cases) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1]) +@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestBinaryConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): @@ -101,12 +101,12 @@ def _sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index= idx = target == ignore_index target = target[~idx] preds = preds[~idx] - return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize) + return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) @pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1]) +@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestMulticlassConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): @@ -159,13 +159,13 @@ def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index= idx = t == ignore_index t = t[~idx] p = p[~idx] - confmat.append(sk_confusion_matrix(t, p, normalize=normalize)) + confmat.append(sk_confusion_matrix(t, p, normalize=normalize, labels=[0, 1])) return np.stack(confmat, axis=0) @pytest.mark.parametrize("input", _multilabel_cases) @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1]) +@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestMultilabelConfusionMatrix(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 67558add06e..0d6d6847665 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -20,9 +20,9 @@ from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all -from tests.helpers.testers import THRESHOLD, MetricTester, inject_ignore_index +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from torchmetrics.classification.stat_scores import BinaryStatScores -from torchmetrics.functional.classification.stat_scores import binary_stat_scores +from torchmetrics.functional.classification.stat_scores import binary_stat_scores, multilabel_stat_scores seed_all(42) @@ -101,15 +101,63 @@ def 
test_binary_stat_scores_functional(self, input, ignore_index, multidim_avera @pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) +@pytest.mark.parametrize("average", ["micro", "macro", "samples"]) +@pytest.mark.parametrize("top_k", [1, 2]) class TestMulitclassStatScores(MetricTester): pass +def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + stat_scores = [] + for i in range(preds.shape[1]): + p, t = preds[:, i], target[:, i] + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() + stat_scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + return np.stack(stat_scores, axis=0) + + @pytest.mark.parametrize("input", _multilabel_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) +@pytest.mark.parametrize("average", ["micro", "macro", "samples"]) class TestMultilabelyStatScores(MetricTester): - pass + def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_stat_scores, + sk_metric=partial( + _sk_stat_scores_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) # -------------------------- Old stuff -------------------------- diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py index 7b5b271ca7a..66d7b2a83f5 100644 --- a/torchmetrics/functional/classification/confusion_matrix.py +++ b/torchmetrics/functional/classification/confusion_matrix.py @@ -256,6 +256,12 @@ def _multilabel_confusion_matrix_tensor_validation( # Check that they have same shape _check_same_shape(preds, target) + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + # Check that target only contains [0,1] values or value in ignore_index unique_values = torch.unique(target) if ignore_index is None: diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 1634129d89f..384cbee9a9f 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -190,6 +190,9 @@ def multiclass_stat_scores( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + import pdb + + pdb.set_trace() 
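# (A hedged sketch, not part of the patch; `_multilabel_format_sketch` is an
# illustrative name.) The multilabel path added below first binarizes
# probabilities/logits with a threshold and then folds any extra dimensions so
# that rows are (sample, position) pairs and columns are labels; the same
# mask-and-sum counting as the binary case then applies column-wise.
import torch

def _multilabel_format_sketch(preds, target, num_labels, threshold=0.5):
    if preds.is_floating_point():
        if not ((0 <= preds) & (preds <= 1)).all():
            preds = preds.sigmoid()  # values outside [0, 1] are treated as logits
        preds = (preds > threshold).long()
    # (N, L, ...) -> (N * prod(...), L): one column per label
    preds = preds.movedim(1, -1).reshape(-1, num_labels)
    target = target.movedim(1, -1).reshape(-1, num_labels)
    return preds, target

preds = torch.rand(4, 3, 2)        # 4 samples, 3 labels, one extra dim
target = torch.randint(2, (4, 3, 2))
p, t = _multilabel_format_sketch(preds, target, num_labels=3)
print(p.shape, t.shape)            # torch.Size([8, 3]) torch.Size([8, 3])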
     if validate_args:
         _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
         _multiclass_stat_scores_tensor_validation(
@@ -200,20 +203,97 @@ def multiclass_stat_scores(
     return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average)
 
 
-def _multilabel_stat_scores_arg_validation():
-    pass
+def _multilabel_stat_scores_arg_validation(
+    num_labels: int,
+    threshold: float = 0.5,
+    average: str = "micro",
+    multidim_average: str = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    if not isinstance(num_labels, int) or num_labels < 2:
+        raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}")
+    if not isinstance(threshold, float):
+        raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
+    allowed_average = ("micro", "macro", "samples")
+    if average not in allowed_average:
+        raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}")
+    allowed_multidim_average = ("global", "samplewise")
+    if multidim_average not in allowed_multidim_average:
+        raise ValueError(
+            f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}"
+        )
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
 
 
-def _multilabel_stat_scores_tensor_validation():
-    pass
+def _multilabel_stat_scores_tensor_validation(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    multidim_average: str,
+    ignore_index: Optional[int] = None,
+) -> None:
+    # Check that they have same shape
+    _check_same_shape(preds, target)
+    if preds.shape[1] != num_labels:
+        raise ValueError(
+            "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels"
+            f" but got {preds.shape[1]} and expected {num_labels}"
+        )
 
 
-def _multilabel_stat_scores_format():
-    pass
+    # Check that target only contains [0,1] values or value in ignore_index
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = torch.any((unique_values != 0) & (unique_values != 1))
+    else:
+        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
+    if check:
+        raise RuntimeError(
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0, 1] if ignore_index is None else [0, 1, ignore_index]}."
+        )
+
+    # If preds is label tensor, also check that it only contains [0,1] values
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if torch.any((unique_values != 0) & (unique_values != 1)):
+            raise RuntimeError(
+                f"Detected the following values in `preds`: {unique_values} but expected only"
+                " the following values [0,1] since preds is a label tensor."
+ ) + if multidim_average != "global" and preds.ndim < 3: + raise ValueError("Expected input to be atleast 3D when multidim_average is set to `samplewise`") -def _multilabel_stat_scores_update(): - pass + +def _multilabel_stat_scores_format( + preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: + if preds.is_floating_point(): + if not ((0 <= preds) * (preds <= 1)).all(): + preds = preds.sigmoid() + preds = preds > threshold + preds = preds.movedim(1, -1).reshape(-1, num_labels) + target = target.movedim(1, -1).reshape(-1, num_labels) + + if ignore_index is not None: + idx = target == ignore_index + target = target.clone() + target[idx] = -1 + + return preds, target + + +def _multilabel_stat_scores_update( + preds: Tensor, target: Tensor, multidim_average: str = "global" +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + sum_dim = [0, 1] if multidim_average == "global" else 1 + tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() + fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() + fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() + tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() + return tp, fp, tn, fn def _multilabel_stat_scores_compute( @@ -236,8 +316,8 @@ def multilabel_stat_scores( ) -> Tensor: if validate_args: _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) - _multilabel_stat_scores_tensor_validation(preds, target, num_labels, average, multidim_average, ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) return _multilabel_stat_scores_compute(tp, fp, tn, fn, multidim_average) From f11c1e7cb1b6514c8828ff4c96c8eb0db13c10ad Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Jun 2022 15:51:12 +0200 Subject: [PATCH 17/74] update --- tests/classification/inputs.py | 2 +- tests/classification/test_stat_scores.py | 65 +++++++- torchmetrics/classification/stat_scores.py | 1 - .../classification/confusion_matrix.py | 4 +- .../functional/classification/stat_scores.py | 145 ++++++++++++++++-- 5 files changed, 192 insertions(+), 25 deletions(-) diff --git a/tests/classification/inputs.py b/tests/classification/inputs.py index aa328fd14c4..48e6d00b04a 100644 --- a/tests/classification/inputs.py +++ b/tests/classification/inputs.py @@ -109,7 +109,7 @@ def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), ), Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), ), Input( diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 0d6d6847665..9adfc9b12a8 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -22,7 +22,11 @@ from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from torchmetrics.classification.stat_scores import BinaryStatScores -from 
torchmetrics.functional.classification.stat_scores import binary_stat_scores, multilabel_stat_scores +from torchmetrics.functional.classification.stat_scores import ( + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, +) seed_all(42) @@ -98,13 +102,64 @@ def test_binary_stat_scores_functional(self, input, ignore_index, multidim_avera ) +def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if preds.ndim == target.ndim + 1: + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + res = np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return res * (w / w.sum()).reshape(-1, 1) + elif average is None or average == "none": + return res + + @pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) -@pytest.mark.parametrize("average", ["micro", "macro", "samples"]) -@pytest.mark.parametrize("top_k", [1, 2]) -class TestMulitclassStatScores(MetricTester): - pass +@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +class TestMulticlassStatScores(MetricTester): + def test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_stat_scores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index 56a54e952b5..c328e20ff8e 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -210,7 +210,6 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.fn += fn def compute(self) -> Tensor: - return _multilabel_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) diff --git a/torchmetrics/functional/classification/confusion_matrix.py b/torchmetrics/functional/classification/confusion_matrix.py index 66d7b2a83f5..03e1b206b47 100644 --- a/torchmetrics/functional/classification/confusion_matrix.py +++ b/torchmetrics/functional/classification/confusion_matrix.py @@ -193,7 +193,9 @@ def _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, i ) -def _multiclass_confusion_matrix_format(preds, target, ignore_index) -> Tuple[Tensor, Tensor]: +def _multiclass_confusion_matrix_format( + preds: 
Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: # Apply argmax if we have one more dimension if preds.ndim == target.ndim + 1: preds = preds.argmax(dim=1) diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 384cbee9a9f..6f98a91a7ed 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -17,6 +17,7 @@ from torch import Tensor, tensor from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.data import _bincount from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod from torchmetrics.utilities.prints import rank_zero_warn @@ -141,6 +142,10 @@ def _multiclass_stat_scores_arg_validation( raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if not isinstance(top_k, int) and top_k < 1: raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") + if top_k > num_classes: + raise ValueError( + f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}" + ) allowed_average = ("micro", "macro", "weighted,", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") @@ -157,24 +162,133 @@ def _multiclass_stat_scores_tensor_validation( preds: Tensor, target: Tensor, num_classes: int, - top_k: int = 1, multidim_average: str = "global", ignore_index: Optional[int] = None, ) -> None: - pass + if preds.ndim == target.ndim + 1: + if not preds.is_floating_point(): + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" + " equal to number of classes." + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + if multidim_average != "global" and preds.ndim < 3: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should " + " atleast 3D when multidim_average is set to `samplewise`" + ) + elif preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + if multidim_average != "global" and preds.ndim < 2: + raise ValueError( + "When `preds` and `target` have the same shape, the shape of `preds` should " + " atleast 2D when multidim_average is set to `samplewise`" + ) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) -def _multiclass_stat_scores_format(): - pass + unique_values = torch.unique(target) + if ignore_index is None: + check = len(unique_values) > num_classes + else: + check = len(unique_values) > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. 
Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found" + f"{len(unique_values)} in `target`." + ) + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if len(unique_values) > num_classes: + raise RuntimeError( + "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {len(unique_values)} in `preds`." + ) -def _multiclass_stat_scores_update(): - pass + +def _multiclass_stat_scores_format( + preds: Tensor, + target: Tensor, + top_k: int = 1, +): + # Apply argmax if we have one more dimension + if preds.ndim == target.ndim + 1 and top_k == 1: + preds = preds.argmax(dim=1) + if top_k != 1: + preds = preds.reshape(*preds.shape[:2], -1) + else: + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + return preds, target + + +def _multiclass_stat_scores_update( + preds: Tensor, + target: Tensor, + num_classes: int, + average: str = "micro", + top_k: int = 1, + multidim_average: str = "global", + ignore_index: Optional[int] = None, +): + if multidim_average == "samplewise": + if top_k > 1: + _, preds = torch.topk(preds, k=top_k, dim=1) + + preds_oh = torch.nn.functional.one_hot(preds, num_classes) + target_oh = torch.nn.functional.one_hot(target, num_classes) + sum_dim = [1] if top_k == 1 else [1, 2] + tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) + fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) + fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) + tn = ((target_oh == preds_oh) & (target_oh == 0)).sum(sum_dim) + return tp, fn, fp, tn + else: + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + confmat = bins.reshape(num_classes, num_classes) + tp = confmat.diag() + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + return tp, fp, tn, fn def _multiclass_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: + res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.float().mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res if multidim_average == "global": return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) @@ -184,23 +298,20 @@ def multiclass_stat_scores( preds: Tensor, target: Tensor, num_classes: int, - top_k: int = 1, average: str = "micro", + top_k: int = 1, multidim_average: str = "global", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - import pdb - - pdb.set_trace() if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) - _multiclass_stat_scores_tensor_validation( - preds, target, num_classes, top_k, average, multidim_average, ignore_index - ) - preds, target = _multiclass_stat_scores_format(preds, target, ignore_index) - tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, multidim_average) - 
return _multiclass_stat_scores_compute(tp, fp, tn, fn, multidim_average) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, num_classes, average, top_k, multidim_average, ignore_index + ) + return _multiclass_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) def _multilabel_stat_scores_arg_validation( From 5a1495c85381a8dd29e7a1c5d160e092bd3eec63 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Wed, 8 Jun 2022 13:21:08 +0200 Subject: [PATCH 18/74] update --- tests/classification/test_stat_scores.py | 98 +++++++++++++------ tests/helpers/testers.py | 5 +- .../functional/classification/stat_scores.py | 31 +++--- 3 files changed, 94 insertions(+), 40 deletions(-) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 9adfc9b12a8..074ef05c7b8 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -15,12 +15,13 @@ import numpy as np import pytest +import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index +from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, BATCH_SIZE, NUM_BATCHES from torchmetrics.classification.stat_scores import BinaryStatScores from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, @@ -36,12 +37,20 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): preds = preds.view(-1).numpy() target = target.view(-1).numpy() else: + #if preds.shape[0] == BATCH_SIZE * NUM_BATCHES: + # preds = torch.chunk(preds, NUM_BATCHES) + # preds = torch.cat([*preds[1::2], *preds[::2]], 0).numpy() + # target = torch.chunk(target, NUM_BATCHES) + # target = torch.cat([*target[1::2], *target[::2]], 0).numpy() + #else: preds = preds.numpy() target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): if not ((0 < preds) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) + if multidim_average == "global": if ignore_index is not None: idx = target == ignore_index @@ -52,6 +61,8 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): else: res = [] for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() if ignore_index is not None: idx = true == ignore_index true = true[~idx] @@ -65,7 +76,7 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) class TestBinaryStatScores(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ddp", [False, True]) def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): preds, target = input if ignore_index == -1: @@ -103,33 +114,64 @@ def test_binary_stat_scores_functional(self, input, ignore_index, multidim_avera def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): - preds = preds.numpy() - target = target.numpy() if preds.ndim == target.ndim + 1: - preds = 
np.argmax(preds, axis=1) - preds = preds.flatten() - target = target.flatten() - if ignore_index is not None: - idx = target == ignore_index - target = target[~idx] - preds = preds[~idx] - confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) - tp = np.diag(confmat) - fp = confmat.sum(0) - tp - fn = confmat.sum(1) - tp - tn = confmat.sum() - (fp + fn + tp) - - res = np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) - if average == "micro": - return res.sum(0) - elif average == "macro": - return res.mean(0) - elif average == "weighted": - w = tp + fn - return res * (w / w.sum()).reshape(-1, 1) - elif average is None or average == "none": - return res - + preds = torch.argmax(preds, 1) + if multidim_average == 'global': + preds = preds.numpy().flatten() + target = target.numpy().flatten() + + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + res = np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + else: + preds = preds.numpy() + target = target.numpy() + + res = [ ] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + if average == "micro": + res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).sum(0)) + elif average == "macro": + res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).mean(0)) + elif average == "weighted": + w = tp + fn + res.append( + (np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) * (w / w.sum()).reshape(-1, 1)).sum(0) + ) + elif average is None or average == "none": + res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1)) + return np.stack(res, 0) @pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 5da7cb3a3c9..7db8a58b179 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -35,7 +35,7 @@ NUM_PROCESSES = 2 NUM_BATCHES = 4 # Need to be divisible with the number of processes -BATCH_SIZE = 32 +BATCH_SIZE = 3 # NUM_BATCHES = 10 if torch.cuda.is_available() else 4 # BATCH_SIZE = 64 if torch.cuda.is_available() else 32 NUM_CLASSES = 5 @@ -243,6 +243,9 @@ def _class_test( } sk_result = sk_metric(total_preds, total_target, **total_kwargs_update) + print(sk_result) + print(result) + # assert after aggregation if isinstance(sk_result, dict): for key in sk_result.keys(): diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 6f98a91a7ed..0eecccb5949 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -248,17 +248,25 @@ def _multiclass_stat_scores_update( ignore_index: Optional[int] 
= None, ): if multidim_average == "samplewise": + if ignore_index is not None: + preds = preds.clone() + target = target.clone() + idx = target == ignore_index + preds[idx] = num_classes + target[idx] = num_classes if top_k > 1: _, preds = torch.topk(preds, k=top_k, dim=1) - - preds_oh = torch.nn.functional.one_hot(preds, num_classes) - target_oh = torch.nn.functional.one_hot(target, num_classes) + preds_oh = torch.nn.functional.one_hot(preds, num_classes if ignore_index is None else num_classes+1) + target_oh = torch.nn.functional.one_hot(target, num_classes if ignore_index is None else num_classes+1) + if ignore_index is not None: + preds_oh = preds_oh[...,:-1] + target_oh = target_oh[...,:-1] sum_dim = [1] if top_k == 1 else [1, 2] tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) tn = ((target_oh == preds_oh) & (target_oh == 0)).sum(sum_dim) - return tp, fn, fp, tn + return tp, fp, tn, fn else: preds = preds.flatten() target = target.flatten() @@ -279,19 +287,18 @@ def _multiclass_stat_scores_update( def _multiclass_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: - res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + + res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) + sum_dim = 0 if multidim_average=='global' else 1 if average == "micro": - return res.sum(0) + return res.sum(sum_dim) elif average == "macro": - return res.float().mean(0) + return res.float().mean(sum_dim) elif average == "weighted": w = tp + fn - return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + return (res * (w / w.sum()).reshape(*w.shape, 1)).sum(sum_dim) elif average is None or average == "none": return res - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) def multiclass_stat_scores( @@ -304,6 +311,8 @@ def multiclass_stat_scores( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + import pdb + pdb.set_trace() if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) From f82d807404a8f0c106af1a0e3cdda2ba5ffe554a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jun 2022 11:50:26 +0000 Subject: [PATCH 19/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/classification/test_stat_scores.py | 25 +++++++++---------- .../functional/classification/stat_scores.py | 11 ++++---- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 074ef05c7b8..e822cebe4ca 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -21,7 +21,7 @@ from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, BATCH_SIZE, NUM_BATCHES +from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from 
torchmetrics.classification.stat_scores import BinaryStatScores from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, @@ -37,12 +37,12 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): preds = preds.view(-1).numpy() target = target.view(-1).numpy() else: - #if preds.shape[0] == BATCH_SIZE * NUM_BATCHES: + # if preds.shape[0] == BATCH_SIZE * NUM_BATCHES: # preds = torch.chunk(preds, NUM_BATCHES) # preds = torch.cat([*preds[1::2], *preds[::2]], 0).numpy() # target = torch.chunk(target, NUM_BATCHES) # target = torch.cat([*target[1::2], *target[::2]], 0).numpy() - #else: + # else: preds = preds.numpy() target = target.numpy() @@ -116,10 +116,10 @@ def test_binary_stat_scores_functional(self, input, ignore_index, multidim_avera def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) - if multidim_average == 'global': + if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - + if ignore_index is not None: idx = target == ignore_index target = target[~idx] @@ -140,16 +140,16 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av return (res * (w / w.sum()).reshape(-1, 1)).sum(0) elif average is None or average == "none": return res - + else: preds = preds.numpy() target = target.numpy() - - res = [ ] + + res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - + if ignore_index is not None: idx = true == ignore_index true = true[~idx] @@ -159,20 +159,19 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av fp = confmat.sum(0) - tp fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) - + if average == "micro": res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).sum(0)) elif average == "macro": res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).mean(0)) elif average == "weighted": w = tp + fn - res.append( - (np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) * (w / w.sum()).reshape(-1, 1)).sum(0) - ) + res.append((np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) * (w / w.sum()).reshape(-1, 1)).sum(0)) elif average is None or average == "none": res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1)) return np.stack(res, 0) + @pytest.mark.parametrize("input", _multiclass_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 0eecccb5949..7b5936463ca 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -256,11 +256,11 @@ def _multiclass_stat_scores_update( target[idx] = num_classes if top_k > 1: _, preds = torch.topk(preds, k=top_k, dim=1) - preds_oh = torch.nn.functional.one_hot(preds, num_classes if ignore_index is None else num_classes+1) - target_oh = torch.nn.functional.one_hot(target, num_classes if ignore_index is None else num_classes+1) + preds_oh = torch.nn.functional.one_hot(preds, num_classes if ignore_index is None else num_classes + 1) + target_oh = torch.nn.functional.one_hot(target, num_classes if ignore_index is None else num_classes + 1) if ignore_index is not None: - preds_oh = preds_oh[...,:-1] - target_oh = target_oh[...,:-1] + preds_oh = preds_oh[..., :-1] + target_oh = 
target_oh[..., :-1] sum_dim = [1] if top_k == 1 else [1, 2] tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) @@ -289,7 +289,7 @@ def _multiclass_stat_scores_compute( ) -> Tensor: res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) - sum_dim = 0 if multidim_average=='global' else 1 + sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": return res.sum(sum_dim) elif average == "macro": @@ -312,6 +312,7 @@ def multiclass_stat_scores( validate_args: bool = True, ) -> Tensor: import pdb + pdb.set_trace() if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) From c4cfef3fd3386efb6a40002b3cee5f07f749afe2 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Fri, 10 Jun 2022 13:22:41 +0200 Subject: [PATCH 20/74] add missing tests --- tests/classification/test_stat_scores.py | 56 +++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 9adfc9b12a8..110ec3ac6ee 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -21,7 +21,7 @@ from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index -from torchmetrics.classification.stat_scores import BinaryStatScores +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, @@ -136,6 +136,33 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) class TestMulticlassStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassStatScores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + def test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_average, average): preds, target = input if ignore_index == -1: @@ -188,6 +215,33 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "samples"]) class TestMultilabelyStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == 
"samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelStatScores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): preds, target = input if ignore_index == -1: From fe0c12d08b69c83f9822c2856510628243af8183 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 13 Jun 2022 11:47:23 +0200 Subject: [PATCH 21/74] update --- tests/classification/test_stat_scores.py | 103 ++++++++++------- tests/helpers/testers.py | 2 +- torchmetrics/classification/stat_scores.py | 104 +++++++++--------- .../functional/classification/stat_scores.py | 55 +++++---- 4 files changed, 149 insertions(+), 115 deletions(-) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 8e72dfdd322..f8bc83a4cc8 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -17,7 +17,7 @@ import pytest import torch from scipy.special import expit as sigmoid -from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import confusion_matrix as sk_confusion_matrix, multilabel_confusion_matrix as sk_multilabel_confusion_matrix from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all @@ -182,7 +182,7 @@ def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average preds, target = input if ignore_index == -1: target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 3: + if multidim_average == "samplewise" and target.ndim < 3: pytest.skip("samplewise and non-multidim arrays are not valid") self.run_class_metric_test( @@ -237,51 +237,72 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av if not ((0 < preds) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) - target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) - stat_scores = [] - for i in range(preds.shape[1]): - p, t = preds[:, i], target[:, i] - if ignore_index is not None: - idx = t == ignore_index - t = t[~idx] - p = p[~idx] - tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() - stat_scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) - return np.stack(stat_scores, axis=0) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if multidim_average == 'global': + stat_scores = [] + for i in range(preds.shape[1]): + p, t = preds[:, i].flatten(), target[:, i].flatten() + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() + stat_scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + res = np.stack(stat_scores, axis=0) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average 
== "none": + return res + else: + stat_scores = [] + for i in range(preds.shape[0]): + p, t = preds[i], target[i] + if ignore_index is None: + t = t[~idx] + p = p[~idx] + confmat = sk_multilabel_confusion_matrix(y_true=p.T, target=t.T) + tp, fp, fn, tn = confmat[:,0,0], confmat[:,0,1], confmat[:,1,0], confmat[:,1,1] + @pytest.mark.parametrize("input", _multilabel_cases) @pytest.mark.parametrize("ignore_index", [None, 0, -1]) @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("average", ["micro", "macro", "samples"]) -class TestMultilabelyStatScores(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): - preds, target = input - if ignore_index == -1: - target = inject_ignore_index(target, ignore_index) - if multidim_average == "samplewise" and preds.ndim < 3: - pytest.skip("samplewise and non-multidim arrays are not valid") - - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=MultilabelStatScores, - sk_metric=partial( - _sk_stat_scores_multiclass, - ignore_index=ignore_index, - multidim_average=multidim_average, - average=average, - ), - metric_args={ - "ignore_index": ignore_index, - "multidim_average": multidim_average, - "average": average, - "num_classes": NUM_CLASSES, - }, - ) +class TestMultilabelStatScores(MetricTester): + # @pytest.mark.parametrize("ddp", [True, False]) + # def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + # preds, target = input + # if ignore_index == -1: + # target = inject_ignore_index(target, ignore_index) + # if multidim_average == "samplewise" and preds.ndim < 3: + # pytest.skip("samplewise and non-multidim arrays are not valid") + + # self.run_class_metric_test( + # ddp=ddp, + # preds=preds, + # target=target, + # metric_class=MultilabelStatScores, + # sk_metric=partial( + # _sk_stat_scores_multiclass, + # ignore_index=ignore_index, + # multidim_average=multidim_average, + # average=average, + # ), + # metric_args={ + # "num_labels": NUM_CLASSES, + # "threshold": THRESHOLD, + # "ignore_index": ignore_index, + # "multidim_average": multidim_average, + # "average": average, + # }, + # ) def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): preds, target = input diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 7db8a58b179..7ce5fcffec6 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -35,7 +35,7 @@ NUM_PROCESSES = 2 NUM_BATCHES = 4 # Need to be divisible with the number of processes -BATCH_SIZE = 3 +BATCH_SIZE = 32 # NUM_BATCHES = 10 if torch.cuda.is_available() else 4 # BATCH_SIZE = 64 if torch.cuda.is_available() else 32 NUM_CLASSES = 5 diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index c328e20ff8e..f8c04114504 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -39,7 +39,29 @@ from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -class BinaryStatScores(Metric): +class AbstractStatScores: + # define common functions + def _update_state(self, tp, fp, tn, fn): + if self.multidim_average == "samplewise": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def _final_state(self): + tp = torch.cat(self.tp) if 
isinstance(self.tp, list) else self.tp + fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp + tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn + fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + return tp, fp, tn, fn + + +class BinaryStatScores(Metric, AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -76,26 +98,14 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) - if self.multidim_average == "samplewise": - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - else: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn + self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: - tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp - fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp - tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn - fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + tp, fp, tn, fn = self._final_state() return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) -class MulticlassStatScores(Metric): +class MulticlassStatScores(Metric, AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -110,7 +120,7 @@ def __init__( validate_args: bool = True, **kwargs: Dict[str, Any], ) -> None: - super.__init__(**kwargs) + super().__init__(**kwargs) if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) self.num_classes = num_classes @@ -123,12 +133,10 @@ def __init__( if self.multidim_average == "samplewise" or self.average == "samples": default = lambda: [] dist_reduce_fx = "cat" - elif self.average != "micro": - default = lambda: torch.zeros(num_classes) - dist_reduce_fx = "sum" else: - default = lambda: torch.zeros(1) + default = lambda: torch.zeros(num_classes) dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) @@ -136,25 +144,19 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: - _multiclass_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) - preds, target = _multiclass_stat_scores_format(preds, target, self.threshold, self.ignore_index) - tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, self.multidim_average) - if self.multidim_average == "samplewise" or self.average == "samples": - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - else: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn + _multiclass_stat_scores_tensor_validation(preds, target, self.num_classes, self.multidim_average, self.ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, self.num_classes, self.average, self.top_k, self.multidim_average, self.ignore_index + ) + self._update_state(tp, fp, tn, fn) def compute(self) -> 
Tensor: - return _multiclass_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + tp, fp, tn, fn = self._final_state() + return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class MultilabelStatScores(Metric): +class MultilabelStatScores(Metric, AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -169,7 +171,7 @@ def __init__( validate_args: bool = True, **kwargs: Dict[str, Any], ) -> None: - super.__init__(**kwargs) + super().__init__(**kwargs) if validate_args: _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) self.num_labels = num_labels @@ -182,12 +184,10 @@ def __init__( if self.multidim_average == "samplewise" or self.average == "samples": default = lambda: [] dist_reduce_fx = "cat" - elif self.average != "micro": - default = lambda: torch.zeros(num_labels) - dist_reduce_fx = "sum" else: - default = lambda: torch.zeros(1) + default = lambda: torch.zeros(num_labels) dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) @@ -195,22 +195,18 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: - _multilabel_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) - preds, target = _multilabel_stat_scores_format(preds, target, self.threshold, self.ignore_index) + _multilabel_stat_scores_tensor_validation( + preds, target, self.num_labels, self.multidim_average, self.ignore_index + ) + preds, target = _multilabel_stat_scores_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, self.multidim_average) - if self.multidim_average == "samplewise" or self.average == "samples": - self.tp.append(tp) - self.fp.append(fp) - self.tn.append(tn) - self.fn.append(fn) - else: - self.tp += tp - self.fp += fp - self.tn += tn - self.fn += fn + self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: - return _multilabel_stat_scores_compute(self.tp, self.fp, self.tn, self.fn, self.multidim_average) + tp, fp, tn, fn = self._final_state() + return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) # -------------------------- Old stuff -------------------------- diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index 7b5936463ca..f7bf7dfe1c6 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -248,7 +248,8 @@ def _multiclass_stat_scores_update( ignore_index: Optional[int] = None, ): if multidim_average == "samplewise": - if ignore_index is not None: + ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None + if ignore_index is not None and not ignore_in: preds = preds.clone() target = target.clone() idx = target == ignore_index @@ -256,11 +257,19 @@ def _multiclass_stat_scores_update( target[idx] = num_classes if top_k > 1: _, preds = torch.topk(preds, k=top_k, dim=1) - preds_oh = torch.nn.functional.one_hot(preds, num_classes if ignore_index is None else num_classes + 1) - target_oh = torch.nn.functional.one_hot(target, num_classes if ignore_index is None else 
num_classes + 1) + preds_oh = torch.nn.functional.one_hot( + preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) + target_oh = torch.nn.functional.one_hot( + target, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) if ignore_index is not None: - preds_oh = preds_oh[..., :-1] - target_oh = target_oh[..., :-1] + if 0 <= ignore_index <= num_classes - 1: + target_oh[target == ignore_index, :] = -1 + else: + preds_oh = preds_oh[..., :-1] + target_oh = target_oh[..., :-1] + target_oh[target == num_classes, :] = -1 sum_dim = [1] if top_k == 1 else [1, 2] tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) @@ -287,7 +296,7 @@ def _multiclass_stat_scores_update( def _multiclass_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: - + res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": @@ -295,8 +304,11 @@ def _multiclass_stat_scores_compute( elif average == "macro": return res.float().mean(sum_dim) elif average == "weighted": - w = tp + fn - return (res * (w / w.sum()).reshape(*w.shape, 1)).sum(sum_dim) + weight = tp + fn + if multidim_average == "global": + return (res * (weight / weight.sum()).reshape(*weight.shape, 1)).sum(sum_dim) + else: + return (res * (weight / weight.sum(-1, keepdim=True)).reshape(*weight.shape, 1)).sum(sum_dim) elif average is None or average == "none": return res @@ -311,9 +323,6 @@ def multiclass_stat_scores( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: - import pdb - - pdb.set_trace() if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) @@ -395,8 +404,8 @@ def _multilabel_stat_scores_format( if not ((0 <= preds) * (preds <= 1)).all(): preds = preds.sigmoid() preds = preds > threshold - preds = preds.movedim(1, -1).reshape(-1, num_labels) - target = target.movedim(1, -1).reshape(-1, num_labels) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) if ignore_index is not None: idx = target == ignore_index @@ -409,7 +418,7 @@ def _multilabel_stat_scores_format( def _multilabel_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: str = "global" ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - sum_dim = [0, 1] if multidim_average == "global" else 1 + sum_dim = [0, -1] if multidim_average == "global" else [-1] tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() @@ -418,11 +427,19 @@ def _multilabel_stat_scores_update( def _multilabel_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = 'micro', multidim_average: str = "global" ) -> Tensor: - if multidim_average == "global": - return torch.cat([tp, fp, tn, fn, tp + fp + tn + fn], dim=0) - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=1) + res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) + sum_dim = 0 if multidim_average == "global" else 1 + if average == "micro": + return res.sum(sum_dim) 
+ elif average == "macro": + return res.float().mean(sum_dim) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(*w.shape, 1)).sum(sum_dim) + elif average is None or average == "none": + return res def multilabel_stat_scores( @@ -440,7 +457,7 @@ def multilabel_stat_scores( _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) - return _multilabel_stat_scores_compute(tp, fp, tn, fn, multidim_average) + return _multilabel_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) # -------------------------- Old stuff -------------------------- From 51851395bba82107d3b481bc371c0029a9aab41a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Jun 2022 11:15:10 +0000 Subject: [PATCH 22/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/classification/test_stat_scores.py | 8 ++++---- torchmetrics/classification/stat_scores.py | 6 ++++-- torchmetrics/functional/classification/stat_scores.py | 4 ++-- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index f8bc83a4cc8..16d495b0682 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -17,7 +17,8 @@ import pytest import torch from scipy.special import expit as sigmoid -from sklearn.metrics import confusion_matrix as sk_confusion_matrix, multilabel_confusion_matrix as sk_multilabel_confusion_matrix +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all @@ -239,7 +240,7 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av preds = (preds >= THRESHOLD).astype(np.uint8) preds = preds.reshape(*preds.shape[:2], -1) target = target.reshape(*target.shape[:2], -1) - if multidim_average == 'global': + if multidim_average == "global": stat_scores = [] for i in range(preds.shape[1]): p, t = preds[:, i].flatten(), target[:, i].flatten() @@ -267,8 +268,7 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av t = t[~idx] p = p[~idx] confmat = sk_multilabel_confusion_matrix(y_true=p.T, target=t.T) - tp, fp, fn, tn = confmat[:,0,0], confmat[:,0,1], confmat[:,1,0], confmat[:,1,1] - + tp, fp, fn, tn = confmat[:, 0, 0], confmat[:, 0, 1], confmat[:, 1, 0], confmat[:, 1, 1] @pytest.mark.parametrize("input", _multilabel_cases) diff --git a/torchmetrics/classification/stat_scores.py b/torchmetrics/classification/stat_scores.py index f8c04114504..7ada569eb7d 100644 --- a/torchmetrics/classification/stat_scores.py +++ b/torchmetrics/classification/stat_scores.py @@ -59,7 +59,7 @@ def _final_state(self): tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn return tp, fp, tn, fn - + class BinaryStatScores(Metric, AbstractStatScores): is_differentiable: bool = False @@ -144,7 +144,9 @@ def __init__( def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if 
self.validate_args: - _multiclass_stat_scores_tensor_validation(preds, target, self.num_classes, self.multidim_average, self.ignore_index) + _multiclass_stat_scores_tensor_validation( + preds, target, self.num_classes, self.multidim_average, self.ignore_index + ) preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) tp, fp, tn, fn = _multiclass_stat_scores_update( preds, target, self.num_classes, self.average, self.top_k, self.multidim_average, self.ignore_index diff --git a/torchmetrics/functional/classification/stat_scores.py b/torchmetrics/functional/classification/stat_scores.py index f7bf7dfe1c6..41f60a5aba2 100644 --- a/torchmetrics/functional/classification/stat_scores.py +++ b/torchmetrics/functional/classification/stat_scores.py @@ -296,7 +296,7 @@ def _multiclass_stat_scores_update( def _multiclass_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: - + res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": @@ -427,7 +427,7 @@ def _multilabel_stat_scores_update( def _multilabel_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = 'micro', multidim_average: str = "global" + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 From 9bd7e7666657a308df1bb64863d474ab5da92815 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Jun 2022 11:18:55 +0000 Subject: [PATCH 23/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/unittests/classification/test_confusion_matrix.py | 2 +- test/unittests/classification/test_stat_scores.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/unittests/classification/test_confusion_matrix.py b/test/unittests/classification/test_confusion_matrix.py index db5e26011eb..5862e6eeb71 100644 --- a/test/unittests/classification/test_confusion_matrix.py +++ b/test/unittests/classification/test_confusion_matrix.py @@ -19,10 +19,10 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix - from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index + from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, diff --git a/test/unittests/classification/test_stat_scores.py b/test/unittests/classification/test_stat_scores.py index 16d495b0682..cc65dc91766 100644 --- a/test/unittests/classification/test_stat_scores.py +++ b/test/unittests/classification/test_stat_scores.py @@ -19,10 +19,10 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix - from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index + from 
torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, From cb5eabf81b7b959dd7f2e4420d74e1a224f9d2ea Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 14 Jun 2022 12:50:40 +0200 Subject: [PATCH 24/74] multilabel stat scores --- .../classification/confusion_matrix.py | 6 +- .../classification/stat_scores.py | 6 +- .../functional/classification/stat_scores.py | 6 +- .../classification/test_confusion_matrix.py | 134 ++++++++++- .../classification/test_stat_scores.py | 225 ++++++++++++++---- test/unittests/helpers/testers.py | 21 +- 6 files changed, 325 insertions(+), 73 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index bb1a3032fbc..832ad5b87e8 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -49,7 +49,7 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[str] = None, validate_args: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: @@ -83,7 +83,7 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[str] = None, validate_args: bool = True, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: @@ -118,7 +118,7 @@ def __init__( ignore_index: Optional[int] = None, normalize: Optional[str] = None, validate_args: bool = True, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index b694b03efea..79eeaf39a7a 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -72,7 +72,7 @@ def __init__( multidim_average: str = "global", ignore_index: Optional[int] = None, validate_args: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: @@ -118,7 +118,7 @@ def __init__( multidim_average: str = "global", ignore_index: Optional[int] = None, validate_args: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: @@ -171,7 +171,7 @@ def __init__( multidim_average: str = "global", ignore_index: Optional[int] = None, validate_args: bool = True, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 41f60a5aba2..e7ce6c50351 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -146,7 +146,7 @@ def _multiclass_stat_scores_arg_validation( raise ValueError( f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}" ) - allowed_average = ("micro", "macro", "weighted,", "weighted", "samples", "none", None) + allowed_average = ("micro", "macro", "weighted", "none", None) if average not in allowed_average: raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") allowed_multidim_average = ("global", "samplewise") @@ -344,9 +344,9 @@ def _multilabel_stat_scores_arg_validation( raise ValueError(f"Expected argument `num_labels` to 
be an integer larger than 1, but got {num_labels}") if not isinstance(threshold, float): raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") - allowed_average = ("micro", "macro", "samples") + allowed_average = ("micro", "macro", "weighted", "none", None) if average not in allowed_average: - raise ValueError(f"Expected argument `multidim_average` to be one of {allowed_average}, but got {average}") + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") allowed_multidim_average = ("global", "samplewise") if multidim_average not in allowed_multidim_average: raise ValueError( diff --git a/test/unittests/classification/test_confusion_matrix.py b/test/unittests/classification/test_confusion_matrix.py index db5e26011eb..1cfcad7a271 100644 --- a/test/unittests/classification/test_confusion_matrix.py +++ b/test/unittests/classification/test_confusion_matrix.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Dict import numpy as np import pytest @@ -20,9 +19,6 @@ from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from torchmetrics.classification.confusion_matrix import ( BinaryConfusionMatrix, MulticlassConfusionMatrix, @@ -33,6 +29,9 @@ multiclass_confusion_matrix, multilabel_confusion_matrix, ) +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) @@ -52,9 +51,9 @@ def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None @pytest.mark.parametrize("input", _binary_cases) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestBinaryConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input @@ -73,6 +72,8 @@ def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): }, ) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: @@ -89,6 +90,43 @@ def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index }, ) + def test_binary_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not 
support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + def _sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): preds = preds.numpy() @@ -106,9 +144,9 @@ def _sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index= @pytest.mark.parametrize("input", _multiclass_cases) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestMulticlassConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input @@ -127,6 +165,8 @@ def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): }, ) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: @@ -143,6 +183,41 @@ def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_i }, ) + def test_multiclass_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None): preds = preds.numpy() @@ -165,9 +240,9 @@ def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index= @pytest.mark.parametrize("input", _multilabel_cases) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize("ignore_index", [None, -1, 0]) class TestMultilabelConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + 
@pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): preds, target = input @@ -186,6 +261,8 @@ def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): }, ) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_index): preds, target = input if ignore_index is not None: @@ -202,6 +279,43 @@ def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_i }, ) + def test_multilabel_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + def test_warning_on_nan(): preds = torch.randint(3, size=(20,)) diff --git a/test/unittests/classification/test_stat_scores.py b/test/unittests/classification/test_stat_scores.py index 16d495b0682..30c402660dc 100644 --- a/test/unittests/classification/test_stat_scores.py +++ b/test/unittests/classification/test_stat_scores.py @@ -18,17 +18,16 @@ import torch from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix -from tests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases -from tests.helpers import seed_all -from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, multilabel_stat_scores, ) +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers import seed_all +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) @@ -74,9 +73,9 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): @pytest.mark.parametrize("input", _binary_cases) -@pytest.mark.parametrize("ignore_index", [None, 0, -1]) 
-@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) class TestBinaryStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) @pytest.mark.parametrize("ddp", [False, True]) def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): preds, target = input @@ -94,6 +93,8 @@ def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) def test_binary_stat_scores_functional(self, input, ignore_index, multidim_average): preds, target = input if ignore_index == -1: @@ -113,6 +114,43 @@ def test_binary_stat_scores_functional(self, input, ignore_index, multidim_avera }, ) + def test_binary_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): if preds.ndim == target.ndim + 1: @@ -174,10 +212,10 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av @pytest.mark.parametrize("input", _multiclass_cases) -@pytest.mark.parametrize("ignore_index", [None, 0, -1]) -@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) -@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) class TestMulticlassStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) @pytest.mark.parametrize("ddp", [True, False]) def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average, average): preds, target = input @@ -205,6 +243,9 @@ def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average }, ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_average, average): preds, target = input if ignore_index == -1: @@ -230,6 +271,43 @@ def 
test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_a }, ) + def test_multiclass_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() @@ -251,59 +329,77 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() stat_scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) res = np.stack(stat_scores, axis=0) + if average == "micro": return res.sum(0) elif average == "macro": return res.mean(0) elif average == "weighted": - w = tp + fn + w = res[:, 0] + res[:, 3] return (res * (w / w.sum()).reshape(-1, 1)).sum(0) elif average is None or average == "none": return res else: stat_scores = [] for i in range(preds.shape[0]): - p, t = preds[i], target[i] - if ignore_index is None: - t = t[~idx] - p = p[~idx] - confmat = sk_multilabel_confusion_matrix(y_true=p.T, target=t.T) - tp, fp, fn, tn = confmat[:, 0, 0], confmat[:, 0, 1], confmat[:, 1, 0], confmat[:, 1, 1] + scores = [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + stat_scores.append(np.stack(scores, 1)) + res = np.stack(stat_scores, 0) + if average == "micro": + return res.sum(-1) + elif average == "macro": + return res.mean(-1) + elif average == "weighted": + w = res[:, 0, :] + res[:, 3, :] + return (res * (w / w.sum())[:, np.newaxis]).sum(-1) + elif average is None or average == "none": + return np.moveaxis(res, 1, -1) @pytest.mark.parametrize("input", _multilabel_cases) -@pytest.mark.parametrize("ignore_index", [None, 0, -1]) -@pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) -@pytest.mark.parametrize("average", ["micro", "macro", "samples"]) class TestMultilabelStatScores(MetricTester): - # @pytest.mark.parametrize("ddp", [True, False]) - # def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): - # preds, target = input - # if ignore_index == -1: - # target = inject_ignore_index(target, ignore_index) - # if multidim_average == "samplewise" and preds.ndim < 3: 
- # pytest.skip("samplewise and non-multidim arrays are not valid") - - # self.run_class_metric_test( - # ddp=ddp, - # preds=preds, - # target=target, - # metric_class=MultilabelStatScores, - # sk_metric=partial( - # _sk_stat_scores_multiclass, - # ignore_index=ignore_index, - # multidim_average=multidim_average, - # average=average, - # ), - # metric_args={ - # "num_labels": NUM_CLASSES, - # "threshold": THRESHOLD, - # "ignore_index": ignore_index, - # "multidim_average": multidim_average, - # "average": average, - # }, - # ) + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelStatScores, + sk_metric=partial( + _sk_stat_scores_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): preds, target = input if ignore_index == -1: @@ -330,6 +426,43 @@ def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_a }, ) + def test_multilabel_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + # -------------------------- Old stuff -------------------------- diff --git a/test/unittests/helpers/testers.py b/test/unittests/helpers/testers.py index 7ce5fcffec6..c237e9054d7 100644 --- a/test/unittests/helpers/testers.py +++ 
b/test/unittests/helpers/testers.py
@@ -304,12 +304,13 @@ def _functional_test(
     _assert_allclose(tm_result, sk_result, atol=atol)


-def _assert_half_support(
+def _assert_dtype_support(
     metric_module: Optional[Metric],
     metric_functional: Optional[Callable],
     preds: Tensor,
     target: Tensor,
     device: str = "cpu",
+    dtype: torch.dtype = torch.half,
     **kwargs_update,
 ):
     """Test if a metric can be used with half precision tensors.
@@ -323,10 +324,10 @@ def _assert_half_support(
         kwargs_update: Additional keyword arguments that will be passed with preds and
             target when running update on the metric.
     """
-    y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device)
-    y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device)
+    y_hat = preds[0].to(dtype=dtype, device=device) if preds[0].is_floating_point() else preds[0].to(device)
+    y = target[0].to(dtype=dtype, device=device) if target[0].is_floating_point() else target[0].to(device)
     kwargs_update = {
-        k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v
+        k: (v[0].to(dtype=dtype) if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v
         for k, v in kwargs_update.items()
     }
     if metric_module is not None:
@@ -486,6 +487,7 @@ def run_precision_test_cpu(
         metric_module: Optional[Metric] = None,
         metric_functional: Optional[Callable] = None,
         metric_args: Optional[dict] = None,
+        dtype: torch.dtype = torch.half,
         **kwargs_update,
     ):
         """Test if a metric can be used with half precision tensors on cpu
@@ -499,12 +501,13 @@ def run_precision_test_cpu(
                 target when running update on the metric.
         """
         metric_args = metric_args or {}
-        _assert_half_support(
+        _assert_dtype_support(
             metric_module(**metric_args) if metric_module is not None else None,
-            metric_functional,
+            partial(metric_functional, **metric_args),
             preds,
             target,
             device="cpu",
+            dtype=dtype,
             **kwargs_update,
         )

@@ -515,6 +518,7 @@ def run_precision_test_gpu(
         metric_module: Optional[Metric] = None,
         metric_functional: Optional[Callable] = None,
         metric_args: Optional[dict] = None,
+        dtype: torch.dtype = torch.half,
         **kwargs_update,
     ):
         """Test if a metric can be used with half precision tensors on gpu
@@ -528,12 +532,13 @@ def run_precision_test_gpu(
                 target when running update on the metric.
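+                dtype: the dtype the floating point tensors are cast to before the test (default: ``torch.half``)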
""" metric_args = metric_args or {} - _assert_half_support( + _assert_dtype_support( metric_module(**metric_args) if metric_module is not None else None, - metric_functional, + partial(metric_functional, **metric_args), preds, target, device="cuda", + dtype=dtype, **kwargs_update, ) From 44df55566de45c2fee863f8827b1165af8b7d098 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 14 Jun 2022 13:01:43 +0200 Subject: [PATCH 25/74] disable old testing --- setup.cfg | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/setup.cfg b/setup.cfg index 19420c48db5..fd1e9182886 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,29 @@ doctest_plus = enabled addopts = --strict --color=yes + --ignore=test/unittests/classification/test_accuracy.py + --ignore=test/unittests/classification/test_auc.py + --ignore=test/unittests/classification/test_auroc.py + --ignore=test/unittests/classification/test_average_precision.py + --ignore=test/unittests/classification/test_binned_precision_recall.py + --ignore=test/unittests/classification/test_calibration_error.py + --ignore=test/unittests/classification/test_cohen_kappa.py + #--ignore=test/unittests/classification/test_confusion_matrix.py + --ignore=test/unittests/classification/test_dice.py + --ignore=test/unittests/classification/test_f_beta.py + --ignore=test/unittests/classification/test_hamming_distance.py + --ignore=test/unittests/classification/test_hinge.py + #--ignore=test/unittests/classification/test_inputs.py + --ignore=test/unittests/classification/test_jaccard.py + --ignore=test/unittests/classification/test_kl_divergence.py + --ignore=test/unittests/classification/test_matthews_corrcoef.py + --ignore=test/unittests/classification/test_precision_recall_curve.py + --ignore=test/unittests/classification/test_precision_recall.py + --ignore=test/unittests/classification/test_ranking.py + --ignore=test/unittests/classification/test_roc.py + --ignore=test/unittests/classification/test_specificity.py + #--ignore=test/unittests/classification/test_stat_scores.py + doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS From 5ba6ef3e67da568ac8e65318968009264b87a773 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 14 Jun 2022 15:50:57 +0200 Subject: [PATCH 26/74] more testing --- .../classification/stat_scores.py | 2 +- .../functional/classification/stat_scores.py | 27 +++++----- .../classification/test_stat_scores.py | 50 +++++++++++++------ 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 79eeaf39a7a..bb7eef2a8e3 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -149,7 +149,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore ) preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) tp, fp, tn, fn = _multiclass_stat_scores_update( - preds, target, self.num_classes, self.average, self.top_k, self.multidim_average, self.ignore_index + preds, target, self.num_classes, self.top_k, self.multidim_average, self.ignore_index ) self._update_state(tp, fp, tn, fn) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index e7ce6c50351..57b244d2b69 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -17,7 +17,7 @@ from torch import Tensor, tensor from 
torchmetrics.utilities.checks import _check_same_shape, _input_format_classification -from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.data import _bincount, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod from torchmetrics.utilities.prints import rank_zero_warn @@ -112,7 +112,7 @@ def _binary_stat_scores_update( def _binary_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" ) -> Tensor: - return torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=0 if multidim_average == "global" else 1).squeeze() + return torch.stack([tp, fp, tn, fn, tp + fn], dim=0 if multidim_average == "global" else 1).squeeze() def binary_stat_scores( @@ -242,12 +242,11 @@ def _multiclass_stat_scores_update( preds: Tensor, target: Tensor, num_classes: int, - average: str = "micro", top_k: int = 1, multidim_average: str = "global", ignore_index: Optional[int] = None, ): - if multidim_average == "samplewise": + if multidim_average == "samplewise" or top_k != 1: ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None if ignore_index is not None and not ignore_in: preds = preds.clone() @@ -255,11 +254,13 @@ def _multiclass_stat_scores_update( idx = target == ignore_index preds[idx] = num_classes target[idx] = num_classes + if top_k > 1: - _, preds = torch.topk(preds, k=top_k, dim=1) - preds_oh = torch.nn.functional.one_hot( - preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes - ) + preds_oh = select_topk(preds, topk=top_k, dim=1).movedim(1, -1) + else: + preds_oh = torch.nn.functional.one_hot( + preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) target_oh = torch.nn.functional.one_hot( target, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes ) @@ -270,7 +271,7 @@ def _multiclass_stat_scores_update( preds_oh = preds_oh[..., :-1] target_oh = target_oh[..., :-1] target_oh[target == num_classes, :] = -1 - sum_dim = [1] if top_k == 1 else [1, 2] + sum_dim = [0, 1] if multidim_average == "global" else [1] tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) @@ -297,7 +298,7 @@ def _multiclass_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" ) -> Tensor: - res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": return res.sum(sum_dim) @@ -327,9 +328,7 @@ def multiclass_stat_scores( _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) preds, target = _multiclass_stat_scores_format(preds, target, top_k) - tp, fp, tn, fn = _multiclass_stat_scores_update( - preds, target, num_classes, average, top_k, multidim_average, ignore_index - ) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) return _multiclass_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) @@ -429,7 +428,7 @@ def _multilabel_stat_scores_update( def _multilabel_stat_scores_compute( tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = 
"micro", multidim_average: str = "global" ) -> Tensor: - res = torch.stack([tp, fp, tn, fn, tp + fp + tn + fn], dim=-1) + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": return res.sum(sum_dim) diff --git a/test/unittests/classification/test_stat_scores.py b/test/unittests/classification/test_stat_scores.py index 30c402660dc..e12f7f1c77c 100644 --- a/test/unittests/classification/test_stat_scores.py +++ b/test/unittests/classification/test_stat_scores.py @@ -37,12 +37,6 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): preds = preds.view(-1).numpy() target = target.view(-1).numpy() else: - # if preds.shape[0] == BATCH_SIZE * NUM_BATCHES: - # preds = torch.chunk(preds, NUM_BATCHES) - # preds = torch.cat([*preds[1::2], *preds[::2]], 0).numpy() - # target = torch.chunk(target, NUM_BATCHES) - # target = torch.cat([*target[1::2], *target[::2]], 0).numpy() - # else: preds = preds.numpy() target = target.numpy() @@ -57,7 +51,7 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): target = target[~idx] preds = preds[~idx] tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() - return np.array([tp, fp, tn, fn, tp + fp + tn + fn]) + return np.array([tp, fp, tn, fn, tp + fn]) else: res = [] for pred, true in zip(preds, target): @@ -68,7 +62,7 @@ def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): true = true[~idx] pred = pred[~idx] tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() - res.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + res.append(np.array([tp, fp, tn, fn, tp + fn])) return np.stack(res) @@ -169,7 +163,7 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) - res = np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) + res = np.stack([tp, fp, tn, fn, tp + fn], 1) if average == "micro": return res.sum(0) elif average == "macro": @@ -198,16 +192,16 @@ def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, av fp = confmat.sum(0) - tp fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) - + r = np.stack([tp, fp, tn, fn, tp + fn], 1) if average == "micro": - res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).sum(0)) + res.append(r.sum(0)) elif average == "macro": - res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1).mean(0)) + res.append(r.mean(0)) elif average == "weighted": w = tp + fn - res.append((np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1) * (w / w.sum()).reshape(-1, 1)).sum(0)) + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) elif average is None or average == "none": - res.append(np.stack([tp, fp, tn, fn, tp + fp + tn + fn], 1)) + res.append(r) return np.stack(res, 0) @@ -309,6 +303,30 @@ def test_multiclass_stat_scores_half_gpu(self, input, dtype): ) +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) + + +@pytest.mark.parametrize( + "k, preds, target, average, expected", + [ + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (1, _mc_k_preds, _mc_k_target, None, torch.tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, None, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), + 
], +) +def test_top_k_multiclass(k, preds, target, average, expected): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassStatScores(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) + + assert torch.allclose(class_metric.compute().long(), expected.T) + assert torch.allclose( + multiclass_stat_scores(preds, target, top_k=k, average=average, num_classes=3).long(), expected.T + ) + + def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() @@ -327,7 +345,7 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av t = t[~idx] p = p[~idx] tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() - stat_scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + stat_scores.append(np.array([tp, fp, tn, fn, tp + fn])) res = np.stack(stat_scores, axis=0) if average == "micro": @@ -350,7 +368,7 @@ def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, av true = true[~idx] pred = pred[~idx] tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() - scores.append(np.array([tp, fp, tn, fn, tp + fp + tn + fn])) + scores.append(np.array([tp, fp, tn, fn, tp + fn])) stat_scores.append(np.stack(scores, 1)) res = np.stack(stat_scores, 0) if average == "micro": From 5baec905af8f1672813468de1066d3e47bf7e592 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Wed, 15 Jun 2022 13:02:02 +0200 Subject: [PATCH 27/74] flaky tests --- test/unittests/classification/test_stat_scores.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/unittests/classification/test_stat_scores.py b/test/unittests/classification/test_stat_scores.py index e12f7f1c77c..328717220ee 100644 --- a/test/unittests/classification/test_stat_scores.py +++ b/test/unittests/classification/test_stat_scores.py @@ -77,6 +77,8 @@ def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): target = inject_ignore_index(target, ignore_index) if multidim_average == "samplewise" and preds.ndim < 3: pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, @@ -217,6 +219,8 @@ def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average target = inject_ignore_index(target, ignore_index) if multidim_average == "samplewise" and target.ndim < 3: pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, @@ -394,6 +398,8 @@ def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average target = inject_ignore_index(target, ignore_index) if multidim_average == "samplewise" and preds.ndim < 4: pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, From 12fed05e2f4b4f39d1319a3be9ab51e45d372a34 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 20 Jun 2022 14:37:29 +0200 Subject: [PATCH 28/74] changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7db6036c871..30f1cfbf5eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,10 @@ and 
this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- +- Classification refactor ( + [#1054](https://github.com/Lightning-AI/metrics/pull/1054), +) + - From fa3adc0a434f2ffabc3e4253e73df8c29f382bc5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Mon, 20 Jun 2022 14:47:29 +0200 Subject: [PATCH 29/74] refactor --- setup.cfg | 1 + .../classification/stat_scores.py | 49 ++++++------------- 2 files changed, 17 insertions(+), 33 deletions(-) diff --git a/setup.cfg b/setup.cfg index fd1e9182886..bee2ab8b6f6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,7 @@ doctest_plus = enabled addopts = --strict --color=yes + # TODO: remove when refactor is done --ignore=test/unittests/classification/test_accuracy.py --ignore=test/unittests/classification/test_auc.py --ignore=test/unittests/classification/test_auroc.py diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index bb7eef2a8e3..9b42e68d762 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -41,7 +41,19 @@ class AbstractStatScores: # define common functions - def _update_state(self, tp, fp, tn, fn): + def _create_state(self, size: int, multidim_average: str) -> None: + if multidim_average == "samplewise": + default = lambda: [] + dist_reduce_fx = "cat" + else: + default = lambda: torch.zeros(size, dtype=torch.long) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + + def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: if self.multidim_average == "samplewise": self.tp.append(tp) self.fp.append(fp) @@ -82,16 +94,7 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - if self.multidim_average == "samplewise": - default = lambda: [] - dist_reduce_fx = "cat" - else: - default = lambda: torch.zeros(1, dtype=torch.long) - dist_reduce_fx = "sum" - self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + self._create_state(1, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: @@ -130,17 +133,7 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - if self.multidim_average == "samplewise" or self.average == "samples": - default = lambda: [] - dist_reduce_fx = "cat" - else: - default = lambda: torch.zeros(num_classes) - dist_reduce_fx = "sum" - - self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + self._create_state(num_classes, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: @@ -183,17 +176,7 @@ def __init__( self.ignore_index = ignore_index self.validate_args = validate_args - if self.multidim_average == "samplewise" or self.average == "samples": - default = lambda: [] - dist_reduce_fx = "cat" - else: - default = lambda: torch.zeros(num_labels) - 
dist_reduce_fx = "sum" - - self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) - self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + self._create_state(num_labels, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore if self.validate_args: From db6a01a1d0239520f68ece9e6e2146533509c2db Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Jun 2022 10:10:13 +0200 Subject: [PATCH 30/74] class interface --- .../classification/cohen_kappa.py | 22 +++++ src/torchmetrics/classification/jaccard.py | 84 ++++++++++++++++++- .../classification/matthews_corrcoef.py | 22 +++++ 3 files changed, 127 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 7146ac14480..3cfb344f3fe 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -16,10 +16,32 @@ import torch from torch import Tensor +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update from torchmetrics.metric import Metric +class BinaryCohenKappa(BinaryConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + +class MulticlassCohenKappa(MulticlassConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + +class MultilabelCohenKappa(MultilabelConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + +# -------------------------- Old stuff -------------------------- + + class CohenKappa(Metric): r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement. 
It is defined as diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index d088f6e0702..6a5755423f8 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -16,8 +16,90 @@ import torch from torch import Tensor +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat +from torchmetrics.functional.classification.jaccard import ( + _binary_jaccard_index_compute, + _binary_jaccard_index_validate_args, + _jaccard_from_confmat, + _multiclass_jaccard_index_compute, + _multilabel_jaccard_index_compute, +) + + +class BinaryJaccardIndex(BinaryConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if validate_args: + _binary_jaccard_index_validate_args(threshold, ignore_index) + super().__init__(threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=False, **kwargs) + + def compute(self) -> Tensor: + return _binary_jaccard_index_compute( + self.confmat, + ) + + +class MulticlassJaccardIndex(MulticlassConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs + ) + + def compute(self) -> Tensor: + return _multiclass_jaccard_index_compute( + self.confmat, + ) + + +class MultilabelJaccardIndex(MultilabelConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_labels=num_labels, + threshold=threshold, + ignore_index=ignore_index, + normalize=None, + validate_args=validate_args, + **kwargs, + ) + + def compute(self) -> Tensor: + return _multilabel_jaccard_index_compute( + self.confmat, + ) + + +# -------------------------- Old stuff -------------------------- class JaccardIndex(ConfusionMatrix): diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index e16778099fe..242fcbce5f7 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -16,6 +16,7 @@ import torch from torch import Tensor +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import ( _matthews_corrcoef_compute, _matthews_corrcoef_update, @@ -23,6 +24,27 @@ from torchmetrics.metric import Metric +class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + +class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): + is_differentiable: bool = False + higher_is_better: 
bool = True + full_state_update: bool = False + + +class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + + +# -------------------------- Old stuff -------------------------- + + class MatthewsCorrCoef(Metric): r"""Calculates `Matthews correlation coefficient`_ that measures the general correlation or quality of a classification. From 55d52f68afaa6cd5db9d164b95b947dd7661d63d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Jun 2022 10:38:41 +0200 Subject: [PATCH 31/74] fixes --- src/torchmetrics/classification/stat_scores.py | 10 +++++----- .../functional/classification/confusion_matrix.py | 14 +++++++++----- .../functional/classification/stat_scores.py | 15 +++++++++------ src/torchmetrics/utilities/data.py | 4 +++- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 9b42e68d762..40032e3b41c 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -39,7 +39,7 @@ from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -class AbstractStatScores: +class AbstractStatScores(Metric): # define common functions def _create_state(self, size: int, multidim_average: str) -> None: if multidim_average == "samplewise": @@ -65,7 +65,7 @@ def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: self.tn += tn self.fn += fn - def _final_state(self): + def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn @@ -73,7 +73,7 @@ def _final_state(self): return tp, fp, tn, fn -class BinaryStatScores(Metric, AbstractStatScores): +class BinaryStatScores(AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -108,7 +108,7 @@ def compute(self) -> Tensor: return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) -class MulticlassStatScores(Metric, AbstractStatScores): +class MulticlassStatScores(AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -151,7 +151,7 @@ def compute(self) -> Tensor: return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class MultilabelStatScores(Metric, AbstractStatScores): +class MultilabelStatScores(AbstractStatScores): is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 03e1b206b47..ec8b0b5d450 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -56,7 +56,7 @@ def _binary_confusion_matrix_arg_validation( def _binary_confusion_matrix_tensor_validation( - preds: Tensor, target: Tensor, ignore_index: Optional[int] = bool + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None ) -> None: """Validate tensor input.""" # Check that they have same shape @@ -99,7 +99,7 @@ def _binary_confusion_matrix_format( target = target[idx] if preds.is_floating_point(): - if not 
((0 <= preds) * (preds <= 1)).all(): + if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid preds = preds.sigmoid() preds = preds > threshold @@ -135,7 +135,9 @@ def binary_confusion_matrix( return _binary_confusion_matrix_compute(confmat, normalize) -def _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) -> None: +def _multiclass_confusion_matrix_arg_validation( + num_classes: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None +) -> None: if not isinstance(num_classes, int) and num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if ignore_index is not None and not isinstance(ignore_index, int): @@ -145,7 +147,9 @@ def _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, norma raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") -def _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) -> None: +def _multiclass_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: """Validate tensor input.""" if preds.ndim == target.ndim + 1: if not preds.is_floating_point(): @@ -290,7 +294,7 @@ def _multilabel_confusion_matrix_format( preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None ) -> Tuple[Tensor, Tensor]: if preds.is_floating_point(): - if not ((0 <= preds) * (preds <= 1)).all(): + if not torch.all((0 <= preds) * (preds <= 1)): preds = preds.sigmoid() preds = preds > threshold preds = preds.movedim(1, -1).reshape(-1, num_labels) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 57b244d2b69..e31649f30db 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -40,7 +40,10 @@ def _binary_stat_scores_arg_validation( def _binary_stat_scores_tensor_validation( - preds: Tensor, target: Tensor, multidim_average: str = "global", ignore_index: Optional[int] = bool + preds: Tensor, + target: Tensor, + multidim_average: str = "global", + ignore_index: Optional[int] = None, ) -> None: """Validate tensor input.""" # Check that they have same shape @@ -79,7 +82,7 @@ def _binary_stat_scores_format( ) -> Tuple[Tensor, Tensor]: """Convert all input to label format.""" if preds.is_floating_point(): - if not ((0 <= preds) * (preds <= 1)).all(): + if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid preds = preds.sigmoid() preds = preds > threshold @@ -226,7 +229,7 @@ def _multiclass_stat_scores_format( preds: Tensor, target: Tensor, top_k: int = 1, -): +) -> Tuple[Tensor, Tensor]: # Apply argmax if we have one more dimension if preds.ndim == target.ndim + 1 and top_k == 1: preds = preds.argmax(dim=1) @@ -245,7 +248,7 @@ def _multiclass_stat_scores_update( top_k: int = 1, multidim_average: str = "global", ignore_index: Optional[int] = None, -): +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: if multidim_average == "samplewise" or top_k != 1: ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None if ignore_index is not None and not ignore_in: @@ -361,7 +364,7 @@ def _multilabel_stat_scores_tensor_validation( num_labels: int, multidim_average: str, ignore_index: Optional[int] = None, -): +) -> None: 
    # Check that they have same shape
     _check_same_shape(preds, target)

@@ -400,7 +403,7 @@ def _multilabel_stat_scores_format(
     preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None
 ) -> Tuple[Tensor, Tensor]:
     if preds.is_floating_point():
-        if not ((0 <= preds) * (preds <= 1)).all():
+        if not torch.all((0 <= preds) * (preds <= 1)):
             preds = preds.sigmoid()
         preds = preds > threshold
     preds = preds.reshape(*preds.shape[:2], -1)
diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py
index 84827d7876f..c9a07e95c3b 100644
--- a/src/torchmetrics/utilities/data.py
+++ b/src/torchmetrics/utilities/data.py
@@ -241,7 +241,7 @@ def _squeeze_if_scalar(data: Any) -> Any:
     return apply_to_collection(data, Tensor, _squeeze_scalar_element_tensor)


-def _bincount(x: Tensor, minlength: int) -> Tensor:
+def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor:
     """``torch.bincount`` currently does not support deterministic mode on GPU. This implementation falls back
     to a for-loop counting occurrences in that case.
@@ -253,6 +253,8 @@ def _bincount(x: Tensor, minlength: int) -> Tensor:
     Returns:
         Number of occurrences for each unique element in x
     """
+    if minlength is None:
+        minlength = len(torch.unique(x))
     if deterministic():
         output = torch.zeros(minlength, device=x.device, dtype=torch.long)
         for i in range(minlength):

From 196953cd7cbf28ca64a9de74db070c5ed0a68a2b Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 21 Jun 2022 10:43:46 +0200
Subject: [PATCH 32/74] typing
---
 src/torchmetrics/classification/stat_scores.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index 40032e3b41c..c5395227742 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -43,10 +43,10 @@ class AbstractStatScores(Metric):
     # define common functions
     def _create_state(self, size: int, multidim_average: str) -> None:
         if multidim_average == "samplewise":
-            default = lambda: []
+            default: Callable[[], list] = lambda: []
             dist_reduce_fx = "cat"
         else:
-            default = lambda: torch.zeros(size, dtype=torch.long)
+            default: Callable[[], Tensor] = lambda: torch.zeros(size, dtype=torch.long)
             dist_reduce_fx = "sum"
         self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx)
         self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx)

From c76b20432646367d48b03cb30695cc886b11ea88 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Tue, 21 Jun 2022 10:46:27 +0200
Subject: [PATCH 33/74] update
---
 setup.cfg | 24 ------------------------
 1 file changed, 24 deletions(-)

diff --git a/setup.cfg b/setup.cfg
index 5ef894c9367..f8083dd1856 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,30 +8,6 @@ doctest_plus = enabled
 addopts =
   --strict
   --color=yes
-  # TODO: remove when refactor is done
-  --ignore=test/unittests/classification/test_accuracy.py
-  --ignore=test/unittests/classification/test_auc.py
-  --ignore=test/unittests/classification/test_auroc.py
-  --ignore=test/unittests/classification/test_average_precision.py
-  --ignore=test/unittests/classification/test_binned_precision_recall.py
-  --ignore=test/unittests/classification/test_calibration_error.py
-  --ignore=test/unittests/classification/test_cohen_kappa.py
-  #--ignore=test/unittests/classification/test_confusion_matrix.py
-  --ignore=test/unittests/classification/test_dice.py
-  --ignore=test/unittests/classification/test_f_beta.py
-
--ignore=test/unittests/classification/test_hamming_distance.py - --ignore=test/unittests/classification/test_hinge.py - #--ignore=test/unittests/classification/test_inputs.py - --ignore=test/unittests/classification/test_jaccard.py - --ignore=test/unittests/classification/test_kl_divergence.py - --ignore=test/unittests/classification/test_matthews_corrcoef.py - --ignore=test/unittests/classification/test_precision_recall_curve.py - --ignore=test/unittests/classification/test_precision_recall.py - --ignore=test/unittests/classification/test_ranking.py - --ignore=test/unittests/classification/test_roc.py - --ignore=test/unittests/classification/test_specificity.py - #--ignore=test/unittests/classification/test_stat_scores.py - doctest_optionflags = NORMALIZE_WHITESPACE ELLIPSIS From f0e5d160061fd1fe827caab2afc4c670d84878d5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 21 Jun 2022 15:56:59 +0200 Subject: [PATCH 34/74] fix tests --- .../functional/classification/average_precision.py | 2 +- tests/unittests/helpers/testers.py | 3 --- tests/unittests/retrieval/helpers.py | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index db256cd00d0..54ba2bee60a 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -101,7 +101,7 @@ def _average_precision_compute( if preds.ndim == target.ndim and target.ndim > 1: weights = target.sum(dim=0).float() else: - weights = _bincount(target, minlength=num_classes).float() + weights = _bincount(target, minlength=2 if num_classes == 1 else num_classes).float() weights = weights / torch.sum(weights) else: weights = None diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index c237e9054d7..a87e420cd87 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -243,9 +243,6 @@ def _class_test( } sk_result = sk_metric(total_preds, total_target, **total_kwargs_update) - print(sk_result) - print(result) - # assert after aggregation if isinstance(sk_result, dict): for key in sk_result.keys(): diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 2561d0fb5c7..4c41479df80 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -485,7 +485,7 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, ): - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_cpu( @@ -508,7 +508,7 @@ def run_precision_test_gpu( if not torch.cuda.is_available(): pytest.skip("Test requires GPU") - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_gpu( From 3e5e8db19e6656d605931c623da02707285b8426 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Jun 2022 13:46:36 +0200 Subject: [PATCH 35/74] Apply suggestions from code review Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com> --- src/torchmetrics/classification/confusion_matrix.py | 4 ++-- src/torchmetrics/classification/stat_scores.py | 10 +++++----- 
 .../functional/classification/average_precision.py |  2 +-
 .../functional/classification/confusion_matrix.py  |  2 +-
 4 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py
index 832ad5b87e8..4ae455abb6c 100644
--- a/src/torchmetrics/classification/confusion_matrix.py
+++ b/src/torchmetrics/classification/confusion_matrix.py
@@ -81,7 +81,7 @@ def __init__(
         self,
         num_classes: int,
         ignore_index: Optional[int] = None,
-        normalize: Optional[str] = None,
+        normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
@@ -116,7 +116,7 @@ def __init__(
         num_labels: int,
         threshold: float = 0.5,
         ignore_index: Optional[int] = None,
-        normalize: Optional[str] = None,
+        normalize: Optional[Literal["none", "true", "pred", "all"]] = None,
         validate_args: bool = True,
         **kwargs: Any,
     ) -> None:
diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py
index c5395227742..4142cffb431 100644
--- a/src/torchmetrics/classification/stat_scores.py
+++ b/src/torchmetrics/classification/stat_scores.py
@@ -81,7 +81,7 @@ def __init__(
         self,
         threshold: float = 0.5,
-        multidim_average: str = "global",
+        multidim_average: Literal["global", "samplewise"] = "global",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
         **kwargs: Any,
@@ -117,8 +117,8 @@ def __init__(
         self,
         num_classes: int,
         top_k: int = 1,
-        average: str = "micro",
-        multidim_average: str = "global",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+        multidim_average: Literal["global", "samplewise"] = "global",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
         **kwargs: Any,
@@ -160,8 +160,8 @@ def __init__(
         self,
         num_labels: int,
         threshold: float = 0.5,
-        average: str = "micro",
-        multidim_average: str = "global",
+        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+        multidim_average: Literal["global", "samplewise"] = "global",
         ignore_index: Optional[int] = None,
         validate_args: bool = True,
         **kwargs: Any,
diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py
index 54ba2bee60a..0522f93fb9e 100644
--- a/src/torchmetrics/functional/classification/average_precision.py
+++ b/src/torchmetrics/functional/classification/average_precision.py
@@ -101,7 +101,7 @@ def _average_precision_compute(
         if preds.ndim == target.ndim and target.ndim > 1:
             weights = target.sum(dim=0).float()
         else:
-            weights = _bincount(target, minlength=2 if num_classes == 1 else num_classes).float()
+            weights = _bincount(target, minlength=max(num_classes, 2)).float()
         weights = weights / torch.sum(weights)
     else:
         weights = None
diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index ec8b0b5d450..a08a34246b2 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -22,7 +22,7 @@ from torchmetrics.utilities.prints import rank_zero_warn


-def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[str] = None, multilabel: bool = False) -> Tensor:
+def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, multilabel: bool = False) -> Tensor:
     allowed_normalize = ("true", "pred", "all", "none", None)
     if normalize not in allowed_normalize:
         raise ValueError(f"Argument `normalize` needs to be one of the following: {allowed_normalize}")

From 887ee26bed959cb9b4d0e588ba31b276c16ee618 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 28 Jun 2022 11:47:12 +0000
Subject: [PATCH 36/74] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../functional/classification/confusion_matrix.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index a08a34246b2..c538b90cd6e 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -22,7 +22,9 @@ from torchmetrics.utilities.prints import rank_zero_warn


-def _confusion_matrix_reduce(confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, multilabel: bool = False) -> Tensor:
+def _confusion_matrix_reduce(
+    confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, multilabel: bool = False
+) -> Tensor:
     allowed_normalize = ("true", "pred", "all", "none", None)
     if normalize not in allowed_normalize:
         raise ValueError(f"Argument `normalize` needs to be one of the following: {allowed_normalize}")

From 2096dd4bc3ef26faaa2210e878ff6451a54e7978 Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Jun 2022 13:47:51 +0200
Subject: [PATCH 37/74] Apply suggestions from code review

Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 .../classification/confusion_matrix.py | 27 ++++++++++++-------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index c538b90cd6e..ce0dcd0c32c 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -31,14 +31,14 @@ def _confusion_matrix_reduce(
     if normalize is not None and normalize != "none":
         confmat = confmat.float() if not confmat.is_floating_point() else confmat
         if normalize == "true":
-            confmat = confmat / confmat.sum(axis=2 if multilabel else 1, keepdim=True)
+            confmat = confmat / confmat.sum(axis=-1, keepdim=True)
         elif normalize == "pred":
-            confmat = confmat / confmat.sum(axis=1 if multilabel else 0, keepdim=True)
+            confmat = confmat / confmat.sum(axis=-2, keepdim=True)
         elif normalize == "all":
-            confmat = confmat / confmat.sum(axis=[1, 2] if multilabel else [0, 1], keepdim=True)
+            confmat = confmat / confmat.sum(axis=[-2, -1], keepdim=True)

         nan_elements = confmat[torch.isnan(confmat)].nelement()
-        if nan_elements != 0:
+        if nan_elements:
             confmat[torch.isnan(confmat)] = 0
             rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.")
     return confmat
@@ -47,7 +47,11 @@ def _confusion_matrix_reduce(
 def _binary_confusion_matrix_arg_validation(
     threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None
 ) -> None:
-    """Validate non tensor input."""
+    """Validate non tensor input.
+    - ``threshold`` has to be a float
+    - ``ignore_index`` has to be None or int
+    - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
+    """
     if not isinstance(threshold, float):
         raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
     if ignore_index is not None and not isinstance(ignore_index, int):
@@ -60,11 +64,14 @@ def _binary_confusion_matrix_arg_validation(

 def _binary_confusion_matrix_tensor_validation(
     preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
 ) -> None:
-    """Validate tensor input."""
+    """Validate tensor input.
+    - tensors have to be of same shape
+    - all values that are not ignored have to be {0, 1}
+    """
     # Check that they have same shape
     _check_same_shape(preds, target)

-    # Check that target only contains [0,1] values or value in ignore_index
+    # Check that target only contains {0,1} values or value in ignore_index
     unique_values = torch.unique(target)
     if ignore_index is None:
         check = torch.any((unique_values != 0) & (unique_values != 1))
@@ -76,7 +83,7 @@ def _binary_confusion_matrix_tensor_validation(
         f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}."
     )

-    # If preds is label tensor, also check that it only contains [0,1] values
+    # If preds is label tensor, also check that it only contains {0,1} values
     if not preds.is_floating_point():
         unique_values = torch.unique(preds)
         if torch.any((unique_values != 0) & (unique_values != 1)):
@@ -92,7 +99,9 @@ def _binary_confusion_matrix_format(
     threshold: float = 0.5,
     ignore_index: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
-    """Convert all input to label format."""
+    """Convert all input to label format.
+    Specifically for preds the sigmoid is applied and the results are thresholded.
+    """
     preds = preds.flatten()
     target = target.flatten()

From 58dc23aefcb4f0c47683ddcf694c2e5efbee4dca Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 28 Jun 2022 11:48:44 +0000
Subject: [PATCH 38/74] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../functional/classification/confusion_matrix.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py
index ce0dcd0c32c..60874a74ddf 100644
--- a/src/torchmetrics/functional/classification/confusion_matrix.py
+++ b/src/torchmetrics/functional/classification/confusion_matrix.py
@@ -48,9 +48,10 @@ def _binary_confusion_matrix_arg_validation(
     threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None
 ) -> None:
     """Validate non tensor input.
-    - ``threshold`` has to be a float
-    - ``ignore_index`` has to be None or int
-    - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
+
+    - ``threshold`` has to be a float
+    - ``ignore_index`` has to be None or int
+    - ``normalize`` has to be "true" | "pred" | "all" | "none" | None
     """
     if not isinstance(threshold, float):
         raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
@@ -65,8 +66,9 @@ def _binary_confusion_matrix_tensor_validation(
     preds: Tensor, target: Tensor, ignore_index: Optional[int] = None
 ) -> None:
     """Validate tensor input.
-    - tensors have to be of same shape
+
+    - tensors have to be of same shape
+    - all values that are not ignored have to be {0, 1}
     """
     # Check that they have same shape
     _check_same_shape(preds, target)
@@ -100,7 +102,8 @@ def _binary_confusion_matrix_format(
     ignore_index: Optional[int] = None,
 ) -> Tuple[Tensor, Tensor]:
     """Convert all input to label format.
-    Specifically for preds the sigmoid is applied and the results are thresholded.
+
+    Specifically for preds the sigmoid is applied and the results are thresholded.
     """
     preds = preds.flatten()
     target = target.flatten()

From c8d82d3b3c37c09646f5a37bd629d5583ffcaf5a Mon Sep 17 00:00:00 2001
From: Nicki Skafte Detlefsen
Date: Tue, 28 Jun 2022 13:51:48 +0200
Subject: [PATCH 39/74] Apply suggestions from code review

Co-authored-by: Daniel Stancl <46073029+stancld@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 .../functional/classification/stat_scores.py | 30 +++++++++----------
 1 file changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py
index e31649f30db..9fcc39fc22e 100644
--- a/src/torchmetrics/functional/classification/stat_scores.py
+++ b/src/torchmetrics/functional/classification/stat_scores.py
@@ -57,8 +57,8 @@ def _binary_stat_scores_tensor_validation(
     check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
     if check:
         raise RuntimeError(
-            "Detected the following values in `target`: {unique_values} but expected only"
-            " the following values {[0,1] + [] if ignore_index is None else [ignore_index]}."
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0, 1] + ([] if ignore_index is None else [ignore_index])}."
         )

     # If preds is label tensor, also check that it only contains [0,1] values
@@ -66,8 +66,8 @@ def _binary_stat_scores_tensor_validation(
     unique_values = torch.unique(preds)
     if torch.any((unique_values != 0) & (unique_values != 1)):
         raise RuntimeError(
-            "Detected the following values in `preds`: {unique_values} but expected only"
-            " the following values [0,1] since preds is a label tensor."
+            f"Detected the following values in `preds`: {unique_values} but expected only"
+            " the following values [0,1] since `preds` is a label tensor."
) if multidim_average != "global" and preds.ndim < 2: @@ -80,7 +80,7 @@ def _binary_stat_scores_format( threshold: float = 0.5, ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - """Convert all input to label format.""" + """"""Brings the prediction and target tensors to a unified format (Flattened and class indices).""" if preds.is_floating_point(): if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid @@ -102,7 +102,7 @@ def _binary_stat_scores_update( preds: Tensor, target: Tensor, multidim_average: str = "global", -) -> Tensor: +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: """""" sum_dim = [0, 1] if multidim_average == "global" else 1 tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() @@ -122,7 +122,7 @@ def binary_stat_scores( preds: Tensor, target: Tensor, threshold: float = 0.5, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -204,16 +204,16 @@ def _multiclass_stat_scores_tensor_validation( " and `preds` should be (N, C, ...)." ) - unique_values = torch.unique(target) + num_unique_values = len(torch.unique(target)) if ignore_index is None: - check = len(unique_values) > num_classes + check = num_unique_values > num_classes else: - check = len(unique_values) > num_classes + 1 + check = num_unique_values > num_classes + 1 if check: raise RuntimeError( "Detected more unique values in `target` than `num_classes`. Expected only " f"{num_classes if ignore_index is None else num_classes + 1} but found" - f"{len(unique_values)} in `target`." + f"{num_unique_values} in `target`." ) if not preds.is_floating_point(): @@ -321,9 +321,9 @@ def multiclass_stat_scores( preds: Tensor, target: Tensor, num_classes: int, - average: str = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", top_k: int = 1, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: @@ -449,8 +449,8 @@ def multilabel_stat_scores( target: Tensor, num_labels: int, threshold: float = 0.5, - average: str = "micro", - multidim_average: str = "global", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: From f2d870c45c82d11ee326cbb13fdeb2937ae4a595 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Jun 2022 16:56:13 +0200 Subject: [PATCH 40/74] Update src/torchmetrics/functional/classification/confusion_matrix.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- src/torchmetrics/functional/classification/confusion_matrix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 60874a74ddf..2fb9d400ec0 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -236,6 +236,7 @@ def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_class def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: + """Reduces the confusion matrix to it's final form. 
Normalization technique can be chosen by ``normalize``.""" return _confusion_matrix_reduce(confmat, normalize, multilabel=False) From 8dec6e4ee79c14797b2816821e250c7f0382d699 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Jun 2022 14:56:48 +0000 Subject: [PATCH 41/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../functional/classification/confusion_matrix.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 2fb9d400ec0..040ba9f63a0 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -236,7 +236,10 @@ def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_class def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - """Reduces the confusion matrix to it's final form. Normalization technique can be chosen by ``normalize``.""" + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ return _confusion_matrix_reduce(confmat, normalize, multilabel=False) From 6c39518be7d72ef3f670e284ce64ad6a4ba7167c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 28 Jun 2022 17:10:00 +0200 Subject: [PATCH 42/74] missing literal --- src/torchmetrics/classification/confusion_matrix.py | 2 +- src/torchmetrics/classification/stat_scores.py | 4 ++-- .../functional/classification/confusion_matrix.py | 10 +++++----- .../functional/classification/stat_scores.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 4ae455abb6c..02cd9dcf399 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Literal, Optional import torch from torch import Tensor diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 4142cffb431..df2b2e3227e 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Literal, Optional, Tuple import torch from torch import Tensor @@ -161,7 +161,7 @@ def __init__( num_labels: int, threshold: float = 0.5, average: Literal["micro", "macro", "samples"] = "micro", - multidim_average: Literal["global", "samplewise"] = "global",, + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 040ba9f63a0..4c1da4025b6 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple import torch from torch import Tensor @@ -23,7 +23,7 @@ def _confusion_matrix_reduce( - confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None, multilabel: bool = False + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None ) -> Tensor: allowed_normalize = ("true", "pred", "all", "none", None) if normalize not in allowed_normalize: @@ -130,7 +130,7 @@ def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: def _binary_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: """Calculate final confusion matrix.""" - return _confusion_matrix_reduce(confmat, normalize, multilabel=False) + return _confusion_matrix_reduce(confmat, normalize) def binary_confusion_matrix( @@ -240,7 +240,7 @@ def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[st Normalization technique can be chosen by ``normalize``. """ - return _confusion_matrix_reduce(confmat, normalize, multilabel=False) + return _confusion_matrix_reduce(confmat, normalize) def multiclass_confusion_matrix( @@ -337,7 +337,7 @@ def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_label def _multilabel_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - return _confusion_matrix_reduce(confmat, normalize, multilabel=True) + return _confusion_matrix_reduce(confmat, normalize) def multilabel_confusion_matrix( diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 9fcc39fc22e..9a8ca40cdca 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import torch from torch import Tensor, tensor @@ -80,7 +80,7 @@ def _binary_stat_scores_format( threshold: float = 0.5, ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - """"""Brings the prediction and target tensors to a unified format (Flattened and class indices).""" + """Brings the prediction and target tensors to a unified format (Flattened and class indices).""" if preds.is_floating_point(): if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid From 05db489a2821044254ee4d7606e143df85fdf2d1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Jun 2022 11:22:53 +0200 Subject: [PATCH 43/74] add docstring to functional confusion matrix --- .../classification/confusion_matrix.py | 263 ++++++++++++++++-- .../functional/classification/stat_scores.py | 2 +- 2 files changed, 235 insertions(+), 30 deletions(-) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 4c1da4025b6..3ccde466346 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -25,6 +25,18 @@ def _confusion_matrix_reduce( confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None ) -> Tensor: + """Reduce an un-normalized confusion matrix + Args: + confmat: un-normalized confusion matrix + normalize: normalization method. + - `"true"` will divide by the sum of the column dimension. + - `"pred"` will divide by the sum of the row dimension. + - `"all"` will divide by the sum of the full matrix + - `"none"` or `None` will apply no reduction + + Returns: + Normalized confusion matrix + """ allowed_normalize = ("true", "pred", "all", "none", None) if normalize not in allowed_normalize: raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}") @@ -45,16 +57,18 @@ def _confusion_matrix_reduce( def _binary_confusion_matrix_arg_validation( - threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ) -> None: """Validate non tensor input. - - ``threshold`` has to be a float + - ``threshold`` has to be a float in the [0,1] range - ``ignore_index`` has to be None or int - ``normalize`` has to be "true" | "pred" | "all" | "none" | None """ - if not isinstance(threshold, float): - raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + if not isinstance(threshold, float) and not (0 <= threshold <= 1): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") allowed_normalize = ("true", "pred", "all", "none", None) @@ -68,7 +82,8 @@ def _binary_confusion_matrix_tensor_validation( """Validate tensor input. 
- tensors have to be of same shape - - all values that are not ignored have to be {0, 1} + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} """ # Check that they have same shape _check_same_shape(preds, target) @@ -103,7 +118,9 @@ def _binary_confusion_matrix_format( ) -> Tuple[Tensor, Tensor]: """Convert all input to label format. - Specifically for targets the sigmoid is applied and the results are thresholded. + - Remove all datapoints that should be ignored + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards """ preds = preds.flatten() target = target.flatten() @@ -122,14 +139,19 @@ def _binary_confusion_matrix_format( def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: - """Calculate confusion matrix on current input.""" + """Computes the bins to update the confusion matrix with.""" unique_mapping = (target * 2 + preds).to(torch.long) bins = _bincount(unique_mapping, minlength=4) return bins.reshape(2, 2) -def _binary_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: - """Calculate final confusion matrix.""" +def _binary_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ return _confusion_matrix_reduce(confmat, normalize) @@ -138,9 +160,49 @@ def binary_confusion_matrix( target: Tensor, threshold: float = 0.5, ignore_index: Optional[int] = None, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, validate_args: bool = True, ) -> Tensor: + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
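The ``normalize`` argument composes with the examples that follow; a minimal sketch of the row-normalized output, assuming only the function signature added in this patch:

    import torch
    from torchmetrics.functional import binary_confusion_matrix  # added by this patch

    target = torch.tensor([1, 1, 0, 0])
    preds = torch.tensor([0, 1, 0, 0])
    # raw counts are [[2, 0], [1, 1]]; normalize="true" divides each row
    # (true class) by its support, giving [[1.0, 0.0], [0.5, 0.5]]
    binary_confusion_matrix(preds, target, normalize="true")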
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + """ if validate_args: _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) @@ -150,9 +212,17 @@ def binary_confusion_matrix( def _multiclass_confusion_matrix_arg_validation( - num_classes: int, ignore_index: Optional[int] = None, normalize: Optional[str] = None + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ) -> None: - if not isinstance(num_classes, int) and num_classes < 2: + """Validate non tensor input. + + - ``num_classes`` has to be a int larger than 1 + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if ignore_index is not None and not isinstance(ignore_index, int): raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") @@ -164,7 +234,14 @@ def _multiclass_confusion_matrix_arg_validation( def _multiclass_confusion_matrix_tensor_validation( preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None ) -> None: - """Validate tensor input.""" + """Validate tensor input. + + - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - if preds and target have same number of dims, then all dimensions should match + - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1} + - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1} + """ if preds.ndim == target.ndim + 1: if not preds.is_floating_point(): raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") @@ -190,30 +267,35 @@ def _multiclass_confusion_matrix_tensor_validation( " and `preds` should be (N, C, ...)." ) - unique_values = torch.unique(target) + num_unique_values = len(torch.unique(target)) if ignore_index is None: - check = len(unique_values) > num_classes + check = num_unique_values > num_classes else: - check = len(unique_values) > num_classes + 1 + check = num_unique_values > num_classes + 1 if check: raise RuntimeError( "Detected more unique values in `target` than `num_classes`. Expected only " - f"{num_classes if ignore_index is None else num_classes + 1} but found" - f"{len(unique_values)} in `target`." + f"{num_classes if ignore_index is None else num_classes + 1} but found " + f"{num_unique_values} in `target`." 
) if not preds.is_floating_point(): - unique_values = torch.unique(preds) - if len(unique_values) > num_classes: + num_unique_values = len(torch.unique(preds)) + if num_unique_values > num_classes: raise RuntimeError( "Detected more unique values in `preds` than `num_classes`. Expected only " - f"{num_classes} but found {len(unique_values)} in `preds`." + f"{num_classes} but found {num_unique_values} in `preds`." ) def _multiclass_confusion_matrix_format( preds: Tensor, target: Tensor, ignore_index: Optional[int] = None ) -> Tuple[Tensor, Tensor]: + """ "Convert all input to label format. + + - Applies argmax if preds have one more dimension than target + - Remove all datapoints that should be ignored + """ # Apply argmax if we have one more dimension if preds.ndim == target.ndim + 1: preds = preds.argmax(dim=1) @@ -230,12 +312,15 @@ def _multiclass_confusion_matrix_format( def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" unique_mapping = (target * num_classes + preds).to(torch.long) bins = _bincount(unique_mapping, minlength=num_classes**2) return bins.reshape(num_classes, num_classes) -def _multiclass_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: +def _multiclass_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: """Reduces the confusion matrix to it's final form. Normalization technique can be chosen by ``normalize``. @@ -248,9 +333,56 @@ def multiclass_confusion_matrix( target: Tensor, num_classes: int, ignore_index: Optional[int] = None, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, validate_args: bool = True, ) -> Tensor: + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (pred is integer tensor): + >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58] + ... [0.22, 0.61, 0.17] + ... [0.71, 0.09, 0.20] + ... [0.82, 0.05, 0.13] + ... 
]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ if validate_args: _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) @@ -260,9 +392,19 @@ def multiclass_confusion_matrix( def _multilabel_confusion_matrix_arg_validation( - num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None, normalize: Optional[str] = None + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, ) -> None: - if not isinstance(num_labels, int) and num_labels < 2: + """Validate non tensor input. + + - ``num_labels`` should be an int larger than 1 + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_labels, int) or num_labels < 2: raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") if not isinstance(threshold, float): raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") @@ -276,7 +418,13 @@ def _multilabel_confusion_matrix_arg_validation( def _multilabel_confusion_matrix_tensor_validation( preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None ) -> None: - """Validate tensor input.""" + """Validate tensor input. + + - tensors have to be of same shape + - the second dimension of both tensors need to be equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + """ # Check that they have same shape _check_same_shape(preds, target) @@ -311,6 +459,12 @@ def _multilabel_confusion_matrix_tensor_validation( def _multilabel_confusion_matrix_format( preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None ) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. 
+ + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all elements that should be ignored with negative numbers for later filtration + """ if preds.is_floating_point(): if not torch.all((0 <= preds) * (preds <= 1)): preds = preds.sigmoid() @@ -321,7 +475,8 @@ def _multilabel_confusion_matrix_format( if ignore_index is not None: preds = preds.clone() target = target.clone() - # make sure that when we map, it will always result in a negative number that we can filter away + # Make sure that when we map, it will always result in a negative number that we can filter away + # Each label correspond to a 2x2 matrix = 4 elements per label idx = target == ignore_index preds[idx] = -4 * num_labels target[idx] = -4 * num_labels @@ -330,13 +485,20 @@ def _multilabel_confusion_matrix_format( def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() unique_mapping = unique_mapping[unique_mapping >= 0] bins = _bincount(unique_mapping, minlength=4 * num_labels) return bins.reshape(num_labels, 2, 2) -def _multilabel_confusion_matrix_compute(confmat: Tensor, normalize: Optional[str] = None) -> Tensor: +def _multilabel_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ return _confusion_matrix_reduce(confmat, normalize) @@ -346,9 +508,52 @@ def multilabel_confusion_matrix( num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, validate_args: bool = True, ) -> Tensor: + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
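The packing done by `_multilabel_confusion_matrix_update` above can be reproduced by hand; this sketch uses plain ``torch.bincount`` in place of the internal `_bincount` helper, and its result matches the per-label matrices in the examples below:

    import torch

    target = torch.tensor([[0, 1, 0], [1, 0, 1]])
    preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
    num_labels = 3
    # each (target, pred) pair selects one of 4 cells, offset by 4 per label,
    # so a single bincount fills every label's 2x2 block in one pass
    mapping = (2 * target + preds + 4 * torch.arange(num_labels)).flatten()
    confmat = torch.bincount(mapping, minlength=4 * num_labels).reshape(num_labels, 2, 2)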
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ if validate_args: _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 9a8ca40cdca..a5c53e66cd2 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -141,7 +141,7 @@ def _multiclass_stat_scores_arg_validation( multidim_average: str = "global", ignore_index: Optional[int] = None, ) -> None: - if not isinstance(num_classes, int) and num_classes < 2: + if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if not isinstance(top_k, int) and top_k < 1: raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}") From ae9ca58d2289647b103c18f8530ab32773a7bb92 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Jun 2022 11:32:05 +0200 Subject: [PATCH 44/74] add docstring to modular confusion matrix --- .../classification/confusion_matrix.py | 153 +++++++++++++++++- .../classification/confusion_matrix.py | 2 +- 2 files changed, 153 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 02cd9dcf399..253317ec532 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -39,6 +39,46 @@ class BinaryConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
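Because the class keeps an additive ``confmat`` state, it can be fed batch by batch; a small sketch of that accumulation pattern, using the class added by this patch series:

    import torch
    from torchmetrics import BinaryConfusionMatrix  # added by this patch series

    metric = BinaryConfusionMatrix()
    batches = [
        (torch.tensor([0, 1]), torch.tensor([1, 1])),
        (torch.tensor([0, 0]), torch.tensor([0, 0])),
    ]
    for preds, target in batches:
        metric.update(preds, target)  # adds this batch's counts to the state
    # same result as a single call on the concatenated tensors: [[2, 0], [1, 1]]
    metric.compute()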
+ + Example (preds is int tensor): + >>> from torchmetrics import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -47,7 +87,7 @@ def __init__( self, threshold: float = 0.5, ignore_index: Optional[int] = None, - normalize: Optional[str] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, validate_args: bool = True, **kwargs: Any, ) -> None: @@ -62,6 +102,12 @@ def __init__( self.add_state("confmat", torch.zeros(2, 2), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) @@ -69,10 +115,58 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: + """Computes confusion matrix. Returns an [2,2] matrix. """ return _binary_confusion_matrix_compute(self.confmat, self.normalize) class MulticlassConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (pred is integer tensor): + >>> from torchmetrics import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58] + ... [0.22, 0.61, 0.17] + ... [0.71, 0.09, 0.20] + ... [0.82, 0.05, 0.13] + ... 
]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -96,6 +190,12 @@ def __init__( self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) @@ -103,10 +203,54 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: + """Computes confusion matrix. Returns an [num_classes, num_classes] matrix. """ return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) class MultilabelConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -132,6 +276,12 @@ def __init__( self.add_state("confmat", torch.zeros(num_labels, 2, 2), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) preds, target = _multilabel_confusion_matrix_format( @@ -141,6 +291,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: + """Computes confusion matrix. Returns an [num_labels,2,2] matrix. """ return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 3ccde466346..de929461ca2 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -524,7 +524,7 @@ def multilabel_confusion_matrix( Args: preds: Tensor with predictions target: Tensor with true labels - num_classes: Integer specifing the number of labels + num_labels: Integer specifing the number of labels threshold: Threshold for transforming probability to binary (0,1) predictions ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation From afcde9c0560b4c13a02a7e419f2ca66c2eaddf79 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Jun 2022 09:32:51 +0000 Subject: [PATCH 45/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/confusion_matrix.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 253317ec532..728d6d9133c 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -115,7 +115,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: - """Computes confusion matrix. Returns an [2,2] matrix. """ + """Computes confusion matrix. + + Returns an [2,2] matrix. + """ return _binary_confusion_matrix_compute(self.confmat, self.normalize) @@ -203,7 +206,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: - """Computes confusion matrix. Returns an [num_classes, num_classes] matrix. """ + """Computes confusion matrix. + + Returns an [num_classes, num_classes] matrix. + """ return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) @@ -291,7 +297,10 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self.confmat += confmat def compute(self) -> Tensor: - """Computes confusion matrix. Returns an [num_labels,2,2] matrix. """ + """Computes confusion matrix. + + Returns an [num_labels,2,2] matrix. 
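All three classes register their state with ``dist_reduce_fx="sum"``, which is sound because confusion-matrix counts are additive across processes; a single-process sketch of that property:

    import torch
    from torchmetrics.functional import binary_confusion_matrix

    preds_a, target_a = torch.tensor([0, 1]), torch.tensor([1, 1])
    preds_b, target_b = torch.tensor([0, 0]), torch.tensor([0, 0])
    # summing per-shard matrices equals computing one matrix over all the data
    partial = binary_confusion_matrix(preds_a, target_a) + binary_confusion_matrix(preds_b, target_b)
    full = binary_confusion_matrix(torch.cat([preds_a, preds_b]), torch.cat([target_a, target_b]))
    assert torch.equal(partial, full)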
+ """ return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) From 24bc7f9c93045f039132bcc97c519d35aba20e1b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Jun 2022 19:52:36 +0200 Subject: [PATCH 46/74] add docstring to functional stat scores --- .../classification/confusion_matrix.py | 17 +- .../functional/classification/stat_scores.py | 374 +++++++++++++++++- 2 files changed, 366 insertions(+), 25 deletions(-) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index de929461ca2..b82393b59e4 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -187,6 +187,9 @@ def binary_confusion_matrix( validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + Returns: + A ``[2, 2]`` tensor + Example (preds is int tensor): >>> from torchmetrics.functional import binary_confusion_matrix >>> target = torch.tensor([1, 1, 0, 0]) @@ -360,6 +363,9 @@ def multiclass_confusion_matrix( validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + Returns: + A ``[num_classes, num_classes]`` tensor + Example (pred is integer tensor): >>> from torchmetrics.functional import multiclass_confusion_matrix >>> target = torch.tensor([2, 1, 0, 0]) @@ -373,10 +379,10 @@ def multiclass_confusion_matrix( >>> from torchmetrics.functional import multiclass_confusion_matrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ - ... [0.16, 0.26, 0.58] - ... [0.22, 0.61, 0.17] - ... [0.71, 0.09, 0.20] - ... [0.82, 0.05, 0.13] + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.82, 0.05, 0.13], ... ]) >>> multiclass_confusion_matrix(preds, target, num_classes=3) tensor([[1, 1, 0], @@ -536,6 +542,9 @@ def multilabel_confusion_matrix( validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + Returns: + A ``[num_labels, 2, 2]`` tensor + Example (preds is int tensor): >>> from torchmetrics.functional import multilabel_confusion_matrix >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index a5c53e66cd2..1262e896110 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -24,12 +24,17 @@ def _binary_stat_scores_arg_validation( threshold: float = 0.5, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, ) -> None: - """Validate non tensor input.""" - if not isinstance(threshold, float): - raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + """Validate non tensor input. 
+ + - ``threshold`` has to be a float in the [0,1] range + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ + if not isinstance(threshold, float) and not (0 <= threshold <= 1): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") allowed_multidim_average = ("global", "samplewise") if multidim_average not in allowed_multidim_average: raise ValueError( @@ -42,10 +47,16 @@ def _binary_stat_scores_arg_validation( def _binary_stat_scores_tensor_validation( preds: Tensor, target: Tensor, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, ) -> None: - """Validate tensor input.""" + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be atleast 2 dimensional + """ # Check that they have same shape _check_same_shape(preds, target) @@ -80,7 +91,12 @@ def _binary_stat_scores_format( threshold: float = 0.5, ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor]: - """Brings the prediction and target tensors to a unified format (Flattened and class indices).""" + """Convert all input to label format. + + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all datapoints that should be ignored with negative values + """ if preds.is_floating_point(): if not torch.all((0 <= preds) * (preds <= 1)): # preds is logits, convert with sigmoid @@ -101,9 +117,9 @@ def _binary_stat_scores_format( def _binary_stat_scores_update( preds: Tensor, target: Tensor, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """""" + """Computes the statistics.""" sum_dim = [0, 1] if multidim_average == "global" else 1 tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() @@ -113,8 +129,9 @@ def _binary_stat_scores_update( def _binary_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: str = "global" + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: Literal["global", "samplewise"] = "global" ) -> Tensor: + """Stack statistics and compute support also.""" return torch.stack([tp, fp, tn, fn, tp + fn], dim=0 if multidim_average == "global" else 1).squeeze() @@ -126,6 +143,69 @@ def binary_stat_scores( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. 
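The counting in `_binary_stat_scores_update` above reduces to four boolean masks; a sketch reproducing it, together with the stacked ``[tp, fp, tn, fn, sup]`` layout, on the example inputs used throughout this patch:

    import torch

    target = torch.tensor([0, 1, 0, 1, 0, 1])
    preds = torch.tensor([0, 0, 1, 1, 0, 1])
    # agreement with the target and the target's value pick exactly one cell
    tp = ((preds == target) & (target == 1)).sum()
    fn = ((preds != target) & (target == 1)).sum()
    fp = ((preds != target) & (target == 0)).sum()
    tn = ((preds == target) & (target == 0)).sum()
    torch.stack([tp, fp, tn, fn, tp + fn])  # tensor([2, 1, 2, 1, 3]), sup = tp + fn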
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        threshold: Threshold for transforming probability to binary {0,1} predictions
+        multidim_average:
+            Defines how additional dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flattened along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Returns:
+        The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds
+        to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape
+        depends on the ``multidim_average`` parameter:
+
+        - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)``
+        - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)``
+
+    Example (preds is int tensor):
+        >>> from torchmetrics.functional import binary_stat_scores
+        >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
+        >>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
+        >>> binary_stat_scores(preds, target)
+        tensor([2, 1, 2, 1, 3])
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.functional import binary_stat_scores
+        >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
+        >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
+        >>> binary_stat_scores(preds, target)
+        tensor([2, 1, 2, 1, 3])
+
+    Example (multidim tensors):
+        >>> from torchmetrics.functional import binary_stat_scores
+        >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]])
+        >>> preds = torch.tensor(
+        ...     [
+        ...         [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]],
+        ...         [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]],
+        ...     ]
+        ... )
+        >>> binary_stat_scores(preds, target, multidim_average='samplewise')
+        tensor([[2, 3, 0, 1, 3],
+                [0, 2, 1, 3, 3]])
+    """
     if validate_args:
         _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index)
         _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index)
@@ -137,10 +217,18 @@ def binary_stat_scores(
 def _multiclass_stat_scores_arg_validation(
     num_classes: int,
     top_k: int = 1,
-    average: str = "micro",
-    multidim_average: str = "global",
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
     ignore_index: Optional[int] = None,
 ) -> None:
+    """Validate non tensor input.
+ + - ``num_classes`` has to be a int larger than 1 + - ``top_k`` has to be an int larger than 0 but no larger than number of classes + - ``average`` has to be "micro" | "macro" | "weighted" | "none" + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") if not isinstance(top_k, int) and top_k < 1: @@ -165,9 +253,19 @@ def _multiclass_stat_scores_tensor_validation( preds: Tensor, target: Tensor, num_classes: int, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, ) -> None: + """Validate tensor input. + + - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - if preds and target have same number of dims, then all dimensions should match + - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be atleast 2 dimensional in the + int case and 3 dimensional in the float case + - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1} + - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1} + """ if preds.ndim == target.ndim + 1: if not preds.is_floating_point(): raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") @@ -230,6 +328,11 @@ def _multiclass_stat_scores_format( target: Tensor, top_k: int = 1, ) -> Tuple[Tensor, Tensor]: + """ "Convert all input to label format except if ``top_k`` is not 1. + + - Applies argmax if preds have one more dimension than target + - Flattens additional dimensions + """ # Apply argmax if we have one more dimension if preds.ndim == target.ndim + 1 and top_k == 1: preds = preds.argmax(dim=1) @@ -246,9 +349,18 @@ def _multiclass_stat_scores_update( target: Tensor, num_classes: int, top_k: int = 1, - multidim_average: str = "global", + multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the statistics. + + - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and + target into one hot format. + - Else we calculate statistics by first calculating the confusion matrix and afterwards deriving the + statistics from that + - Remove all datapoints that should be ignored. Depending on if ``ignore_index`` is in the set of labels + or outside we have do use different augmentation stategies when one hot encoding. + """ if multidim_average == "samplewise" or top_k != 1: ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None if ignore_index is not None and not ignore_in: @@ -298,9 +410,14 @@ def _multiclass_stat_scores_update( def _multiclass_stat_scores_compute( - tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global" + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", ) -> Tensor: - + """Stack statistics and compute support also. 
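For the global, ``top_k=1`` path, the update docstring above mentions deriving the statistics from a confusion matrix; a sketch of that derivation, whose values match the ``average=None`` multiclass example further down:

    import torch

    # rows are true classes, columns are predictions
    confmat = torch.tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])
    tp = confmat.diag()                  # tensor([1, 1, 1])
    fp = confmat.sum(0) - tp             # tensor([0, 1, 0])
    fn = confmat.sum(1) - tp             # tensor([1, 0, 0])
    tn = confmat.sum() - (tp + fp + fn)  # tensor([2, 2, 3])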
Applies average strategy afterwards.""" res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": @@ -327,6 +444,99 @@ def multiclass_stat_scores( ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... 
+        ...     [0.82, 0.05, 0.13],
+        ... ])
+        >>> multiclass_stat_scores(preds, target, num_classes=3)
+        tensor([2, 1, 2, 1, 3])
+        >>> multiclass_stat_scores(preds, target, num_classes=3, average=None)
+        tensor([[2, 0, 2, 0, 2],
+                [1, 0, 3, 0, 1],
+                [1, 0, 3, 0, 1]])
+
+    Example (multidim tensors):
+        >>> from torchmetrics.functional import multiclass_stat_scores
+        >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]])
+        >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]])
+        >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise')
+        tensor([[3, 3, 9, 3, 6],
+               [2, 4, 8, 4, 6]])
+        >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None)
+        tensor([[[2, 1, 3, 0, 2],
+                [0, 1, 3, 2, 2],
+                [1, 1, 3, 1, 2]],
+
+                [[0, 1, 4, 1, 1],
+                [1, 1, 2, 2, 3],
+                [1, 2, 2, 1, 2]]])
+    """
     if validate_args:
         _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index)
         _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index)
@@ -338,11 +548,19 @@ def _multilabel_stat_scores_arg_validation(
     num_labels: int,
     threshold: float = 0.5,
-    average: str = "micro",
-    multidim_average: str = "global",
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
     ignore_index: Optional[int] = None,
 ) -> None:
-    if not isinstance(num_labels, int) and num_labels < 2:
+    """Validate non tensor input.
+
+    - ``num_labels`` should be an int larger than 1
+    - ``threshold`` has to be a float in the [0,1] range
+    - ``average`` has to be "micro" | "macro" | "weighted" | "none"
+    - ``multidim_average`` has to be either "global" or "samplewise"
+    - ``ignore_index`` has to be None or int
+    """
+    if not isinstance(num_labels, int) or num_labels < 2:
         raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}")
     if not isinstance(threshold, float):
         raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.")
@@ -365,6 +583,14 @@ def _multilabel_stat_scores_tensor_validation(
     multidim_average: str,
     ignore_index: Optional[int] = None,
 ) -> None:
+    """Validate tensor input.
+
+    - tensors have to be of the same shape
+    - the second dimension of both tensors needs to be equal to the number of labels
+    - all values in the target tensor that are not ignored have to be in {0, 1}
+    - if the pred tensor is not floating point, then all values also have to be in {0, 1}
+    - if ``multidim_average`` is set to ``samplewise``, the preds tensor needs to be at least 3-dimensional
+    """
     # Check that they have same shape
     _check_same_shape(preds, target)

@@ -402,6 +628,12 @@ def _multilabel_stat_scores_format(
     preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None
 ) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format.
+
+    - If preds tensor is floating point, applies sigmoid if the pred tensor is not in the [0,1] range
+    - If preds tensor is floating point, thresholds afterwards
+    - Mask all elements that should be ignored with negative numbers for later filtration
+    """
     if preds.is_floating_point():
         if not torch.all((0 <= preds) * (preds <= 1)):
             preds = preds.sigmoid()
@@ -418,8 +650,9 @@ def _multilabel_stat_scores_update(
-    preds: Tensor, target: Tensor, multidim_average: str = "global"
+    preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global"
 ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Computes the statistics."""
     sum_dim = [0, -1] if multidim_average == "global" else [-1]
     tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze()
     fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze()
@@ -429,8 +662,14 @@ def _multilabel_stat_scores_compute(
-    tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, average: str = "micro", multidim_average: str = "global"
+    tp: Tensor,
+    fp: Tensor,
+    tn: Tensor,
+    fn: Tensor,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
 ) -> Tensor:
+    """Stack statistics and compute support also. Applies average strategy afterwards."""
     res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)
     sum_dim = 0 if multidim_average == "global" else 1
     if average == "micro":
@@ -454,6 +693,99 @@ def multilabel_stat_scores(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Tensor:
+    r"""
+    Computes the number of true positives, false positives, true negatives, false negatives and the support
+    for multilabel tasks. Related to `Type I and Type II errors`_.
+
+    Accepts the following input tensors:
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
+      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
+      we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, C, ...)``
+    The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average`
+    argument.
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: Calculates statistics for each label and computes weighted average using their support
+            - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction
+
+        multidim_average:
+            Defines how additional dimensions ``...`` should be handled. Should be one of the following:
+
+            - ``global``: Additional dimensions are flattened along the batch dimension
+            - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis.
+              The statistics in this case are calculated over the additional dimensions.
+
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
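Taken together, the format and update helpers above boil down to "sigmoid if the floats look like logits, threshold, then count elementwise agreements per label". A minimal standalone sketch of that pipeline for the ``global`` case (hypothetical helper name, not the patch's exact code):

import torch

def multilabel_counts(preds, target, threshold=0.5):
    # Format step: treat out-of-range floats as logits, then binarize.
    if preds.is_floating_point():
        if not torch.all((preds >= 0) & (preds <= 1)):
            preds = preds.sigmoid()
        preds = (preds > threshold).long()
    # Update step: elementwise counts per label (dim 0 is the batch).
    tp = ((preds == 1) & (target == 1)).sum(0)
    fp = ((preds == 1) & (target == 0)).sum(0)
    tn = ((preds == 0) & (target == 0)).sum(0)
    fn = ((preds == 0) & (target == 1)).sum(0)
    return tp, fp, tn, fn

preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
target = torch.tensor([[0, 1, 0], [1, 0, 1]])
print(multilabel_counts(preds, target))
# (tensor([1, 0, 1]), tensor([0, 0, 1]), tensor([1, 1, 0]), tensor([0, 1, 0]))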
+ + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ if validate_args: _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) From 3116c4a610992eb0e00ab2728aa97bbb01be4f53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Jun 2022 17:54:14 +0000 Subject: [PATCH 47/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../functional/classification/stat_scores.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 1262e896110..4692844bf6d 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -145,7 +145,7 @@ def binary_stat_scores( ) -> Tensor: r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for binary tasks. Related to `Type I and Type II errors`_. + for binary tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, ...)``. 
If preds is a floating point tensor with values outside @@ -153,7 +153,7 @@ def binary_stat_scores( we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` - argument. + argument. Args: preds: Tensor with predictions @@ -176,7 +176,7 @@ def binary_stat_scores( to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape depends on the ``multidim_average`` parameter: - - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` Example (preds is int tensor): @@ -417,7 +417,10 @@ def _multiclass_stat_scores_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ) -> Tensor: - """Stack statistics and compute support also. Applies average strategy afterwards.""" + """Stack statistics and compute support also. + + Applies average strategy afterwards. + """ res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": @@ -446,7 +449,7 @@ def multiclass_stat_scores( ) -> Tensor: r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for multiclass tasks. Related to `Type I and Type II errors`_. + for multiclass tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point @@ -467,7 +470,7 @@ def multiclass_stat_scores( - ``macro``: Calculate statistics for each label and average them - ``weighted``: Calculates statistics for each label and computes weighted average using their support - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction - top_k: + top_k: Number of highest probability or logit score predictions considered to find the correct label. Only works when ``preds`` contain probabilities/logits. multidim_average: @@ -669,7 +672,10 @@ def _multilabel_stat_scores_compute( average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ) -> Tensor: - """Stack statistics and compute support also. Applies average strategy afterwards.""" + """Stack statistics and compute support also. + + Applies average strategy afterwards. + """ res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) sum_dim = 0 if multidim_average == "global" else 1 if average == "micro": @@ -695,7 +701,7 @@ def multilabel_stat_scores( ) -> Tensor: r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for multilabel tasks. Related to `Type I and Type II errors`_. + for multilabel tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside @@ -703,7 +709,7 @@ def multilabel_stat_scores( we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` - argument. + argument. 
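For readers skimming this cleanup: the five numbers documented above are plain elementwise counts. A rough equivalent in a few lines of PyTorch (illustrative only; assumes preds are already thresholded ints), matching the int-tensor example in the docstring:

import torch

preds = torch.tensor([0, 0, 1, 1, 0, 1])
target = torch.tensor([0, 1, 0, 1, 0, 1])
tp = ((preds == 1) & (target == 1)).sum()
fp = ((preds == 1) & (target == 0)).sum()
tn = ((preds == 0) & (target == 0)).sum()
fn = ((preds == 0) & (target == 1)).sum()
print(torch.stack([tp, fp, tn, fn, tp + fn]))  # tensor([2, 1, 2, 1, 3])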
Args: preds: Tensor with predictions From 6f0320a8f10ed265bc9ad32b6c598fb1bf1e6d13 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 29 Jun 2022 20:01:49 +0200 Subject: [PATCH 48/74] add docstring to modular stat scores --- .../classification/stat_scores.py | 274 ++++++++++++++++++ 1 file changed, 274 insertions(+) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index df2b2e3227e..0482f0cea0d 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -42,6 +42,7 @@ class AbstractStatScores(Metric): # define common functions def _create_state(self, size: int, multidim_average: str) -> None: + """Initialize the states for the different statistics.""" if multidim_average == "samplewise": default: Callable[[], list] = lambda: [] dist_reduce_fx = "cat" @@ -54,6 +55,7 @@ def _create_state(self, size: int, multidim_average: str) -> None: self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: + """ Update states depending on multidim_average argument.""" if self.multidim_average == "samplewise": self.tp.append(tp) self.fp.append(fp) @@ -66,6 +68,7 @@ def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: self.fn += fn def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """ Final aggregation in case of list states.""" tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn @@ -74,6 +77,59 @@ def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: class BinaryStatScores(AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
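Because the class-based metrics are stateful, the arguments above feed a typical update/compute loop rather than a single call. A small usage sketch (illustrative values; note the top-level ``torchmetrics.BinaryStatScores`` export is only added later in this series, in the ``__init__.py`` patch):

import torch
from torchmetrics import BinaryStatScores

metric = BinaryStatScores()
metric.update(torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1]))  # batch 1
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 0, 0]))  # batch 2
# expected: tensor([2, 1, 2, 1, 3]) -- counts summed over both batches
print(metric.compute())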
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + + Example (multidim tensors): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -97,6 +153,12 @@ def __init__( self._create_state(1, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) @@ -104,11 +166,102 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: + """ + Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on the ``multidim_average`` parameter: + + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + + """ tp, fp, tn, fn = self._final_state() return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) class MulticlassStatScores(AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. 
Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.82, 0.05, 0.13], + ... ]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([2, 1, 2, 1, 3]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[2, 0, 2, 0, 2], + [1, 0, 3, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise') + tensor([[3, 3, 9, 3, 6], + [2, 4, 8, 4, 6]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[[2, 1, 3, 0, 2], + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -136,6 +289,12 @@ def __init__( self._create_state(num_classes, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _multiclass_stat_scores_tensor_validation( preds, target, self.num_classes, self.multidim_average, self.ignore_index @@ -147,11 +306,105 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: + """ + Computes the final statistics. 
+ + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ tp, fp, tn, fn = self._final_state() return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) class MultilabelStatScores(AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multilabel tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
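For the per-label view specifically, the stateful class should mirror the functional example that follows. A brief sketch, under the same assumption as above about when the top-level export lands:

import torch
from torchmetrics import MultilabelStatScores

metric = MultilabelStatScores(num_labels=3, average=None)
metric.update(torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]),
              torch.tensor([[0, 1, 0], [1, 0, 1]]))
print(metric.compute())
# one [tp, fp, tn, fn, sup] row per label, matching the functional example:
# tensor([[1, 0, 1, 0, 1],
#         [0, 0, 1, 1, 1],
#         [1, 1, 0, 0, 1]])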
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ is_differentiable: bool = False higher_is_better: Optional[bool] = None full_state_update: bool = False @@ -179,6 +432,12 @@ def __init__( self._create_state(num_labels, multidim_average) def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ if self.validate_args: _multilabel_stat_scores_tensor_validation( preds, target, self.num_labels, self.multidim_average, self.ignore_index @@ -190,6 +449,21 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: + """ + Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). 
The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ tp, fp, tn, fn = self._final_state() return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) From 1e26154632086a5c2bb83b5bbae75b2101f5a4c6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Jun 2022 18:04:11 +0000 Subject: [PATCH 49/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../classification/stat_scores.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 0482f0cea0d..6d6a58edd4a 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -55,7 +55,7 @@ def _create_state(self, size: int, multidim_average: str) -> None: self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: - """ Update states depending on multidim_average argument.""" + """Update states depending on multidim_average argument.""" if self.multidim_average == "samplewise": self.tp.append(tp) self.fp.append(fp) @@ -68,7 +68,7 @@ def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: self.fn += fn def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """ Final aggregation in case of list states.""" + """Final aggregation in case of list states.""" tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn @@ -79,7 +79,7 @@ def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: class BinaryStatScores(AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for binary tasks. Related to `Type I and Type II errors`_. + for binary tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside @@ -87,7 +87,7 @@ class BinaryStatScores(AbstractStatScores): we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` - argument. + argument. Args: threshold: Threshold for transforming probability to binary {0,1} predictions @@ -166,17 +166,15 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: - """ - Computes the final statistics. + """Computes the final statistics. Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). 
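As an aside on the compute docstrings being reflowed here: the final step in all three variants is the same stack-then-reduce. A condensed, hypothetical stand-in (macro/weighted branches omitted from the sketch; the patch implements them inside the ``_*_stat_scores_compute`` helpers):

import torch

def stack_and_reduce(tp, fp, tn, fn, average="micro"):
    # tp/fp/tn/fn hold one count per label (shape (C,)); support = tp + fn.
    res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1)  # shape (C, 5)
    if average == "micro":
        return res.sum(0)  # shape (5,): counts summed over all labels
    return res             # average=None/'none': shape (C, 5), no reduction

tp = torch.tensor([1, 0, 1]); fp = torch.tensor([0, 0, 1])
tn = torch.tensor([1, 1, 0]); fn = torch.tensor([0, 1, 0])
print(stack_and_reduce(tp, fp, tn, fn))        # tensor([2, 1, 2, 1, 3])
print(stack_and_reduce(tp, fp, tn, fn, None))  # one row per label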
The shape depends on the ``multidim_average`` parameter: - - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` - """ tp, fp, tn, fn = self._final_state() return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) @@ -185,7 +183,7 @@ def compute(self) -> Tensor: class MulticlassStatScores(AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for multiclass tasks. Related to `Type I and Type II errors`_. + for multiclass tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point @@ -204,7 +202,7 @@ class MulticlassStatScores(AbstractStatScores): - ``macro``: Calculate statistics for each label and average them - ``weighted``: Calculates statistics for each label and computes weighted average using their support - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction - top_k: + top_k: Number of highest probability or logit score predictions considered to find the correct label. Only works when ``preds`` contain probabilities/logits. multidim_average: @@ -306,8 +304,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: - """ - Computes the final statistics. + """Computes the final statistics. Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds @@ -328,7 +325,7 @@ def compute(self) -> Tensor: class MultilabelStatScores(AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support - for multilabel tasks. Related to `Type I and Type II errors`_. + for multilabel tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside @@ -336,7 +333,7 @@ class MultilabelStatScores(AbstractStatScores): we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` - argument. + argument. Args: num_labels: Integer specifing the number of labels @@ -449,8 +446,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore self._update_state(tp, fp, tn, fn) def compute(self) -> Tensor: - """ - Computes the final statistics. + """Computes the final statistics. 
Returns: The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds From fd2325e56e4592b955251234302840c4c682076f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 11:00:50 +0200 Subject: [PATCH 50/74] make private --- src/torchmetrics/classification/stat_scores.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 6d6a58edd4a..4b94d7b82da 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -39,7 +39,7 @@ from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod -class AbstractStatScores(Metric): +class _AbstractStatScores(Metric): # define common functions def _create_state(self, size: int, multidim_average: str) -> None: """Initialize the states for the different statistics.""" @@ -76,7 +76,7 @@ def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: return tp, fp, tn, fn -class BinaryStatScores(AbstractStatScores): +class BinaryStatScores(_AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support for binary tasks. Related to `Type I and Type II errors`_. @@ -180,7 +180,7 @@ def compute(self) -> Tensor: return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) -class MulticlassStatScores(AbstractStatScores): +class MulticlassStatScores(_AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support for multiclass tasks. Related to `Type I and Type II errors`_. @@ -322,7 +322,7 @@ def compute(self) -> Tensor: return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) -class MultilabelStatScores(AbstractStatScores): +class MultilabelStatScores(_AbstractStatScores): r""" Computes the number of true positives, false positives, true negatives, false negatives and the support for multilabel tasks. Related to `Type I and Type II errors`_. From f8d18dcf87ee48595c2a1923ea1ff5a8a9f86764 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 11:09:56 +0200 Subject: [PATCH 51/74] docs --- .../source/classification/confusion_matrix.rst | 18 ++++++++++++++++++ docs/source/classification/stat_scores.rst | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index a1cc43fdfe9..8a219e46a1d 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -15,8 +15,26 @@ ________________ .. autoclass:: torchmetrics.ConfusionMatrix :noindex: +.. autoclass:: torchmetrics.BinaryConfusionMatrix + :noindex: + +.. autoclass:: torchmetrics.MulticlassConfusionMatrix + :noindex: + +.. autoclass:: torchmetrics.MultilabelConfusionMatrix + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.confusion_matrix :noindex: + +.. autofunction:: torchmetrics.functional.binary_confusion_matrix + :noindex: + +.. autofunction:: torchmetrics.functional.multiclass_confusion_matrix + :noindex: + +.. 
autofunction:: torchmetrics.functional.multilabel_confusion_matrix + :noindex: diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 809c3106948..2ec7e35778b 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -15,8 +15,26 @@ ________________ .. autoclass:: torchmetrics.StatScores :noindex: +.. autoclass:: torchmetrics.BinaryStatScores + :noindex: + +.. autoclass:: torchmetrics.MulticlassStatScores + :noindex: + +.. autoclass:: torchmetrics.MultilabelStatScores + :noindex: + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.stat_scores :noindex: + +.. autofunction:: torchmetrics.functional.binary_stat_scores + :noindex: + +.. autofunction:: torchmetrics.functional.multiclass_stat_scores + :noindex: + +.. autofunction:: torchmetrics.functional.multilabel_stat_scores + :noindex: From 40b813f116ec339b9cea1f16523ffda5a1cc3b62 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 11:36:32 +0200 Subject: [PATCH 52/74] fix mypy and doctests --- src/torchmetrics/__init__.py | 12 +++++ .../classification/confusion_matrix.py | 14 +++--- .../classification/stat_scores.py | 47 +++++++++---------- .../classification/confusion_matrix.py | 2 +- .../functional/classification/stat_scores.py | 36 +++++++------- 5 files changed, 60 insertions(+), 51 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 549367ce4da..5f62e270e9a 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -26,6 +26,8 @@ ROC, Accuracy, AveragePrecision, + BinaryConfusionMatrix, + BinaryStatScores, BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision, @@ -43,6 +45,10 @@ LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, + MulticlassConfusionMatrix, + MulticlassStatScores, + MultilabelConfusionMatrix, + MultilabelStatScores, Precision, PrecisionRecallCurve, Recall, @@ -125,6 +131,9 @@ "CHRFScore", "CohenKappa", "ConfusionMatrix", + "BinaryConfusionMatrix", + "MulticlassConfusionMatrix", + "MultilabelConfusionMatrix", "CosineSimilarity", "CoverageError", "Dice", @@ -185,6 +194,9 @@ "SQuAD", "StructuralSimilarityIndexMeasure", "StatScores", + "BinaryStatScores", + "MulticlassStatScores", + "MultilabelStatScores", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 728d6d9133c..fd06b2af88e 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -99,7 +99,7 @@ def __init__( self.normalize = normalize self.validate_args = validate_args - self.add_state("confmat", torch.zeros(2, 2), dist_reduce_fx="sum") + self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. @@ -159,10 +159,10 @@ class MulticlassConfusionMatrix(Metric): >>> from torchmetrics import MulticlassConfusionMatrix >>> target = torch.tensor([2, 1, 0, 0]) >>> preds = torch.tensor([ - ... [0.16, 0.26, 0.58] - ... [0.22, 0.61, 0.17] - ... [0.71, 0.09, 0.20] - ... [0.82, 0.05, 0.13] + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], ... 
]) >>> metric = MulticlassConfusionMatrix(num_classes=3) >>> metric(preds, target) @@ -190,7 +190,7 @@ def __init__( self.normalize = normalize self.validate_args = validate_args - self.add_state("confmat", torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") + self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. @@ -279,7 +279,7 @@ def __init__( self.normalize = normalize self.validate_args = validate_args - self.add_state("confmat", torch.zeros(num_labels, 2, 2), dist_reduce_fx="sum") + self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum") def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore """Update state with predictions and targets. diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 4b94d7b82da..4b0c850a124 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Literal, Optional, Tuple +from typing import Any, Callable, Literal, Optional, Tuple, Union import torch from torch import Tensor @@ -43,11 +43,12 @@ class _AbstractStatScores(Metric): # define common functions def _create_state(self, size: int, multidim_average: str) -> None: """Initialize the states for the different statistics.""" + default: Union[Callable[[], list], Callable[[], Tensor]] if multidim_average == "samplewise": - default: Callable[[], list] = lambda: [] + default = lambda: [] dist_reduce_fx = "cat" else: - default: Callable[[], Tensor] = lambda: torch.zeros(size, dtype=torch.long) + default = lambda: torch.zeros(size, dtype=torch.long) dist_reduce_fx = "sum" self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) @@ -114,7 +115,7 @@ class BinaryStatScores(_AbstractStatScores): >>> from torchmetrics.functional import binary_stat_scores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) - >>> binary_stat_scores(preds, target, num_labels=3) + >>> binary_stat_scores(preds, target) tensor([2, 1, 2, 1, 3]) Example (multidim tensors): @@ -126,7 +127,7 @@ class BinaryStatScores(_AbstractStatScores): ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], ... ] ... ) - >>> binary_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + >>> binary_stat_scores(preds, target, multidim_average='samplewise') tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]]) """ @@ -235,13 +236,13 @@ class MulticlassStatScores(_AbstractStatScores): ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], - ... [0.82, 0.05, 0.13], + ... [0.05, 0.82, 0.13], ... 
]) >>> multiclass_stat_scores(preds, target, num_classes=3) - tensor([2, 1, 2, 1, 3]) + tensor([3, 1, 7, 1, 4]) >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) - tensor([[2, 0, 2, 0, 2], - [1, 0, 3, 0, 1], + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]]) Example (multidim tensors): @@ -250,15 +251,14 @@ class MulticlassStatScores(_AbstractStatScores): >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise') tensor([[3, 3, 9, 3, 6], - [2, 4, 8, 4, 6]]) + [2, 4, 8, 4, 6]]) >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) tensor([[[2, 1, 3, 0, 2], - [0, 1, 3, 2, 2], - [1, 1, 3, 1, 2]], - - [[0, 1, 4, 1, 1], - [1, 1, 2, 2, 3], - [1, 2, 2, 1, 2]]]) + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) """ is_differentiable: bool = False higher_is_better: Optional[bool] = None @@ -268,7 +268,7 @@ def __init__( self, num_classes: int, top_k: int = 1, - average: Literal["micro", "macro", "samples"] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, @@ -394,12 +394,11 @@ class MultilabelStatScores(_AbstractStatScores): [0, 2, 1, 3, 3]]) >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) tensor([[[1, 1, 0, 0, 1], - [1, 1, 0, 0, 1], - [0, 1, 0, 1, 1]], - - [[0, 0, 0, 2, 2], - [0, 2, 0, 0, 0], - [0, 0, 1, 1, 1]]]) + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) """ is_differentiable: bool = False @@ -410,7 +409,7 @@ def __init__( self, num_labels: int, threshold: float = 0.5, - average: Literal["micro", "macro", "samples"] = "micro", + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", multidim_average: Literal["global", "samplewise"] = "global", ignore_index: Optional[int] = None, validate_args: bool = True, diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index b82393b59e4..a16510b6312 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -382,7 +382,7 @@ def multiclass_confusion_matrix( ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], - ... [0.82, 0.05, 0.13], + ... [0.05, 0.82, 0.13], ... ]) >>> multiclass_confusion_matrix(preds, target, num_classes=3) tensor([[1, 1, 0], diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 4692844bf6d..2bc3f18844e 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -190,7 +190,7 @@ def binary_stat_scores( >>> from torchmetrics.functional import binary_stat_scores >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) - >>> binary_stat_scores(preds, target, num_labels=3) + >>> binary_stat_scores(preds, target) tensor([2, 1, 2, 1, 3]) Example (multidim tensors): @@ -202,7 +202,7 @@ def binary_stat_scores( ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], ... ] ... 
) - >>> binary_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + >>> binary_stat_scores(preds, target, multidim_average='samplewise') tensor([[2, 3, 0, 1, 3], [0, 2, 1, 3, 3]]) """ @@ -515,13 +515,13 @@ def multiclass_stat_scores( ... [0.16, 0.26, 0.58], ... [0.22, 0.61, 0.17], ... [0.71, 0.09, 0.20], - ... [0.82, 0.05, 0.13], + ... [0.05, 0.82, 0.13], ... ]) >>> multiclass_stat_scores(preds, target, num_classes=3) - tensor([2, 1, 2, 1, 3]) + tensor([3, 1, 7, 1, 4]) >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) - tensor([[2, 0, 2, 0, 2], - [1, 0, 3, 0, 1], + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], [1, 0, 3, 0, 1]]) Example (multidim tensors): @@ -530,15 +530,14 @@ def multiclass_stat_scores( >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise') tensor([[3, 3, 9, 3, 6], - [2, 4, 8, 4, 6]]) + [2, 4, 8, 4, 6]]) >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) tensor([[[2, 1, 3, 0, 2], - [0, 1, 3, 2, 2], - [1, 1, 3, 1, 2]], - - [[0, 1, 4, 1, 1], - [1, 1, 2, 2, 3], - [1, 2, 2, 1, 2]]]) + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) """ if validate_args: _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) @@ -784,12 +783,11 @@ def multilabel_stat_scores( [0, 2, 1, 3, 3]]) >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) tensor([[[1, 1, 0, 0, 1], - [1, 1, 0, 0, 1], - [0, 1, 0, 1, 1]], - - [[0, 0, 0, 2, 2], - [0, 2, 0, 0, 0], - [0, 0, 1, 1, 1]]]) + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) """ if validate_args: From f79a8893dbb4878a6ae25fcf323ac635e653514f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 13:32:02 +0200 Subject: [PATCH 53/74] fix docs formatting --- .../classification/confusion_matrix.rst | 27 +++++++++++++++++++ docs/source/classification/stat_scores.rst | 27 +++++++++++++++++++ .../classification/confusion_matrix.py | 9 +++++++ .../classification/stat_scores.py | 6 +++++ .../classification/confusion_matrix.py | 9 +++++++ .../functional/classification/stat_scores.py | 11 ++++++++ 6 files changed, 89 insertions(+) diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index 8a219e46a1d..bde2207e043 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -12,29 +12,56 @@ Confusion Matrix Module Interface ________________ +ConfusionMatrix +^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.ConfusionMatrix :noindex: +BinaryConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.BinaryConfusionMatrix :noindex: + :exclude-members: update, compute + +MulticlassConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: torchmetrics.MulticlassConfusionMatrix :noindex: + :exclude-members: update, compute + +MultilabelConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ .. autoclass:: torchmetrics.MultilabelConfusionMatrix :noindex: + :exclude-members: update, compute Functional Interface ____________________ +confusion_matrix +^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.confusion_matrix :noindex: +binary_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^ + .. 
autofunction:: torchmetrics.functional.binary_confusion_matrix :noindex: +multiclass_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.multiclass_confusion_matrix :noindex: +multilabel_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.multilabel_confusion_matrix :noindex: diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 2ec7e35778b..382e6048534 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -12,29 +12,56 @@ Stat Scores Module Interface ________________ +StatScores +^^^^^^^^^^ + .. autoclass:: torchmetrics.StatScores :noindex: +BinaryStatScores +^^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.BinaryStatScores :noindex: + :exclude-members: update, compute + +MulticlassStatScores +^^^^^^^^^^^^^^^^^^^^ .. autoclass:: torchmetrics.MulticlassStatScores :noindex: + :exclude-members: update, compute + +MultilabelStatScores +^^^^^^^^^^^^^^^^^^^^ .. autoclass:: torchmetrics.MultilabelStatScores :noindex: + :exclude-members: update, compute Functional Interface ____________________ +stat_scores +^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.stat_scores :noindex: +binary_stat_scores +^^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.binary_stat_scores :noindex: +multiclass_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.multiclass_stat_scores :noindex: +multilabel_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.multilabel_stat_scores :noindex: diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index fd06b2af88e..0785a7a419a 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -43,10 +43,12 @@ class BinaryConfusionMatrix(Metric): Computes the `confusion matrix`_ for binary tasks. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. Args: @@ -54,6 +56,7 @@ class BinaryConfusionMatrix(Metric): ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions @@ -127,10 +130,12 @@ class MulticlassConfusionMatrix(Metric): Computes the `confusion matrix`_ for multiclass tasks. Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into an int tensor. - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. 
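As orientation for the layout these docstrings describe (rows index the target, columns the prediction), the binary matrix can be built with a small bincount trick. A sketch, not the module's exact implementation; the same idea generalizes to ``num_classes * target + preds`` in the multiclass case:

import torch

def binary_confmat(preds, target):
    # Encode each (target, pred) pair as an index 0..3, count, and reshape.
    idx = 2 * target.flatten() + preds.flatten()
    return torch.bincount(idx, minlength=4).reshape(2, 2)

target = torch.tensor([1, 1, 0, 0])
preds = torch.tensor([0, 1, 0, 0])
print(binary_confmat(preds, target))
# tensor([[2, 0],
#         [1, 1]])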
Args: @@ -138,6 +143,7 @@ class MulticlassConfusionMatrix(Metric): ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions @@ -218,10 +224,12 @@ class MultilabelConfusionMatrix(Metric): Computes the `confusion matrix`_ for multilabel tasks. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. Args: @@ -230,6 +238,7 @@ class MultilabelConfusionMatrix(Metric): ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 4b0c850a124..724c23312ff 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -83,10 +83,12 @@ class BinaryStatScores(_AbstractStatScores): for binary tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. @@ -187,10 +189,12 @@ class MulticlassStatScores(_AbstractStatScores): for multiclass tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into an int tensor. - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. @@ -328,10 +332,12 @@ class MultilabelStatScores(_AbstractStatScores): for multilabel tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. 
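This is also where the multiclass stat scores come from in the non-samplewise path sketched earlier in the series: derive tp/fp/tn/fn from the confusion matrix. A compact illustration (hypothetical helper; the matrix corresponds to the multiclass example with target = [2, 1, 0, 0] and preds = [2, 1, 0, 1]):

import torch

def counts_from_confmat(confmat):
    tp = confmat.diag()
    fp = confmat.sum(0) - tp  # column sums minus diagonal
    fn = confmat.sum(1) - tp  # row sums minus diagonal
    tn = confmat.sum() - (tp + fp + fn)
    return tp, fp, tn, fn

confmat = torch.tensor([[1, 1, 0],
                        [0, 1, 0],
                        [0, 0, 1]])
print(counts_from_confmat(confmat))
# (tensor([1, 1, 1]), tensor([0, 1, 0]), tensor([2, 2, 3]), tensor([1, 0, 0]))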
diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index a16510b6312..c7e9eda5f8e 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -167,10 +167,12 @@ def binary_confusion_matrix( Computes the `confusion matrix`_ for binary tasks. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. Args: @@ -180,6 +182,7 @@ def binary_confusion_matrix( ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions @@ -343,10 +346,12 @@ def multiclass_confusion_matrix( Computes the `confusion matrix`_ for multiclass tasks. Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into an int tensor. - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. Args: @@ -356,6 +361,7 @@ def multiclass_confusion_matrix( ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions @@ -521,10 +527,12 @@ def multilabel_confusion_matrix( Computes the `confusion matrix`_ for multilabel tasks. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension. Args: @@ -535,6 +543,7 @@ def multilabel_confusion_matrix( ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation normalize: Normalization mode for confusion matrix. Choose from: + - ``None`` or ``'none'``: no normalization (default) - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 2bc3f18844e..6c1d1b6813a 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -148,10 +148,12 @@ def binary_stat_scores( for binary tasks. Related to `Type I and Type II errors`_. 
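[Editor's note] The ``normalize`` modes documented above reduce to simple row/column/global sums over the raw confusion matrix. A small sketch, not part of the patch, assuming the usual layout where rows index targets and columns index predictions (consistent with the "normalization over the targets" wording for ``'true'``):

import torch

confmat = torch.tensor([[2.0, 1.0], [0.0, 3.0]])
norm_true = confmat / confmat.sum(dim=1, keepdim=True)  # 'true': each target row sums to 1
norm_pred = confmat / confmat.sum(dim=0, keepdim=True)  # 'pred': each prediction column sums to 1
norm_all = confmat / confmat.sum()                      # 'all': the whole matrix sums to 1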
Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. @@ -452,10 +454,12 @@ def multiclass_stat_scores( for multiclass tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into an int tensor. - ``target`` (int tensor): ``(N, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. @@ -491,8 +495,10 @@ def multiclass_stat_scores( depends on ``average`` and ``multidim_average`` parameters: - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` @@ -703,10 +709,12 @@ def multilabel_stat_scores( for multilabel tasks. Related to `Type I and Type II errors`_. Accepts the following input tensors: + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (int tensor): ``(N, C, ...)`` + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` argument. @@ -741,9 +749,12 @@ def multilabel_stat_scores( depends on ``average`` and ``multidim_average`` parameters: - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` From 6f5ed75c76b206324a00d96c7cbc11aa99ff573c Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 13:52:31 +0200 Subject: [PATCH 54/74] literal backwards --- src/torchmetrics/classification/confusion_matrix.py | 3 ++- src/torchmetrics/classification/stat_scores.py | 3 ++- src/torchmetrics/functional/classification/confusion_matrix.py | 3 ++- src/torchmetrics/functional/classification/stat_scores.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 0785a7a419a..8ee276723c7 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Literal, Optional +from typing import Any, Optional import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import ( _binary_confusion_matrix_arg_validation, diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 724c23312ff..50bb6ed0837 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Literal, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.functional.classification.stat_scores import ( _binary_stat_scores_arg_validation, diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index c7e9eda5f8e..ec48f4b4839 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, Tuple +from typing import Optional, Tuple import torch from torch import Tensor +from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 6c1d1b6813a..028dea2bec3 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Literal, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import Tensor, tensor +from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification from torchmetrics.utilities.data import _bincount, select_topk From 6e98c7f6922ee4d1ebc92b1c4af1d9b42e2ee5ef Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 14:12:37 +0200 Subject: [PATCH 55/74] custom movedim --- .../functional/classification/confusion_matrix.py | 6 +++--- src/torchmetrics/functional/classification/stat_scores.py | 4 ++-- src/torchmetrics/utilities/data.py | 7 +++++++ 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index ec48f4b4839..3e6f22959b4 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -18,7 +18,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification -from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.data import _bincount, _movedim from torchmetrics.utilities.enums import DataType from torchmetrics.utilities.prints import rank_zero_warn @@ -482,8 +482,8 @@ def _multilabel_confusion_matrix_format( if not torch.all((0 <= preds) * (preds <= 1)): preds = preds.sigmoid() preds = preds > threshold - preds = preds.movedim(1, -1).reshape(-1, num_labels) - target = target.movedim(1, -1).reshape(-1, num_labels) + preds = _movedim(preds, 1, -1).reshape(-1, num_labels) + target = _movedim(target, 1, -1).reshape(-1, num_labels) if ignore_index is not None: preds = preds.clone() diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 028dea2bec3..95a44596aad 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -18,7 +18,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification -from torchmetrics.utilities.data import _bincount, select_topk +from torchmetrics.utilities.data import _bincount, _movedim, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod from torchmetrics.utilities.prints import rank_zero_warn @@ -374,7 +374,7 @@ def _multiclass_stat_scores_update( target[idx] = num_classes if top_k > 1: - preds_oh = select_topk(preds, topk=top_k, dim=1).movedim(1, -1) + preds_oh = _movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) else: preds_oh = torch.nn.functional.one_hot( preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes ) diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index c9a07e95c3b..c82c6592edf 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -269,3 +269,10 @@ def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: if tensor1.dtype != tensor2.dtype: tensor2 = tensor2.to(dtype=tensor1.dtype) return torch.allclose(tensor1, tensor2) + + +def _movedim(tensor: Tensor, dim1: int, dim2: int) -> Tensor: + if _TORCH_GREATER_EQUAL_1_7: + return torch.movedim(tensor, dim1, dim2) + else: + return tensor.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1)
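[Editor's note] The pre-1.7 fallback in ``_movedim`` above must squeeze exactly the singleton dimension left at the source position: a bare ``.squeeze()`` would also drop any unrelated size-1 dimension (for example a batch of one), and the return annotation must be the type ``Tensor`` rather than the argument name ``tensor`` (both corrected above). A quick sanity check of the fallback, not part of the patch:

import torch

def _movedim_fallback(t: torch.Tensor, dim1: int, dim2: int) -> torch.Tensor:
    # insert a singleton at the destination, swap it with the source dim,
    # then squeeze only the singleton left behind at the source position
    return t.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1)

t = torch.randn(1, 3, 4)  # leading size-1 dim that a bare squeeze() would also remove
assert _movedim_fallback(t, 1, -1).shape == torch.movedim(t, 1, -1).shape == (1, 4, 3)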
From debaa9124e0ab0cb8439c1942c865f2c7af3324b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 14:28:47 +0200 Subject: [PATCH 56/74] docs --- docs/source/classification/cohen_kappa.rst | 46 ++++++++++++++++ docs/source/classification/jaccard_index.rst | 46 ++++++++++++++++ .../classification/matthews_corr_coef.rst | 52 +++++++++++++++++-- .../classification/cohen_kappa.py | 33 ++++++++++++ src/torchmetrics/classification/jaccard.py | 6 +++ .../classification/matthews_corrcoef.py | 33 ++++++++++++ 6 files changed, 213 insertions(+), 3 deletions(-) diff --git a/docs/source/classification/cohen_kappa.rst b/docs/source/classification/cohen_kappa.rst index 41127fa4187..4a5abbd8094 100644 --- a/docs/source/classification/cohen_kappa.rst +++ b/docs/source/classification/cohen_kappa.rst @@ -12,11 +12,57 @@ Cohen Kappa Module Interface ________________ +CohenKappa +^^^^^^^^^^ + .. autoclass:: torchmetrics.CohenKappa :noindex: +BinaryCohenKappa +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryCohenKappa + :noindex: + :exclude-members: update, compute + +MulticlassCohenKappa +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassCohenKappa + :noindex: + :exclude-members: update, compute + +MultilabelCohenKappa +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelCohenKappa + :noindex: + :exclude-members: update, compute + + Functional Interface ____________________ +cohen_kappa +^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.cohen_kappa :noindex: + +binary_cohen_kappa +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_cohen_kappa + :noindex: + +multiclass_cohen_kappa +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_cohen_kappa + :noindex: + +multilabel_cohen_kappa +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_cohen_kappa + :noindex: diff --git a/docs/source/classification/jaccard_index.rst b/docs/source/classification/jaccard_index.rst index b37e264992d..e08cb5f3ba0 100644 --- a/docs/source/classification/jaccard_index.rst +++ b/docs/source/classification/jaccard_index.rst @@ -10,11 +10,57 @@ Jaccard Index Module Interface ________________ +JaccardIndex +^^^^^^^^^^^^ + .. autoclass:: torchmetrics.JaccardIndex :noindex: +BinaryJaccardIndex +^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryJaccardIndex + :noindex: + :exclude-members: update, compute + +MulticlassJaccardIndex +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassJaccardIndex + :noindex: + :exclude-members: update, compute + +MultilabelJaccardIndex +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelJaccardIndex + :noindex: + :exclude-members: update, compute + + Functional Interface ____________________ +jaccard_index +^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.jaccard_index :noindex: + +binary_jaccard_index +^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_jaccard_index + :noindex: + +multiclass_jaccard_index +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_jaccard_index + :noindex: + +multilabel_jaccard_index +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_jaccard_index + :noindex: diff --git a/docs/source/classification/matthews_corr_coef.rst b/docs/source/classification/matthews_corr_coef.rst index 1cfe173c24c..569a209c722 100644 --- a/docs/source/classification/matthews_corr_coef.rst +++ b/docs/source/classification/matthews_corr_coef.rst @@ -5,18 +5,64 @@ ..
include:: ../links.rst -#################### -Matthews Corr. Coef. -#################### +################################ +Matthews Correlation Coefficient +################################ Module Interface ________________ +MatthewsCorrCoef +^^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.MatthewsCorrCoef :noindex: +BinaryMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + +MulticlassMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + +MultilabelMatthewsCorrCoef +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelMatthewsCorrCoef + :noindex: + :exclude-members: update, compute + + Functional Interface ____________________ +matthews_corrcoef +^^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.matthews_corrcoef :noindex: + +binary_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_matthews_corrcoef + :noindex: + +multiclass_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_matthews_corrcoef + :noindex: + +multilabel_matthews_corrcoef +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_matthews_corrcoef + :noindex: diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 3cfb344f3fe..57f433fb038 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -22,22 +22,55 @@ class BinaryCohenKappa(BinaryConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + class MulticlassCohenKappa(MulticlassConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + class MultilabelCohenKappa(MultilabelConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + # -------------------------- Old stuff -------------------------- diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 6a5755423f8..2bd464b6b9e 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -28,6 +28,8 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False @@ -50,6 +52,8 @@ def compute(self) -> Tensor: class MulticlassJaccardIndex(MulticlassConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False @@ -72,6 +76,8 @@ def compute(self) -> Tensor: class MultilabelJaccardIndex(MultilabelConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False diff --git 
a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 242fcbce5f7..2198d934fac 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -25,22 +25,55 @@ class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): + """""" + is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False + def __init__(self, **kwargs): + pass + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + pass + + def compute(self) -> Tensor: + pass + # -------------------------- Old stuff -------------------------- From 5b3698805b8fa2e35049b4d96d1a1932e11fa364 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 30 Jun 2022 15:45:28 +0200 Subject: [PATCH 57/74] inits --- src/torchmetrics/__init__.py | 18 ++++++++++++++++ src/torchmetrics/classification/__init__.py | 21 ++++++++++++++++--- src/torchmetrics/functional/__init__.py | 21 ++++++++++++++++--- .../functional/classification/cohen_kappa.py | 13 ++++++++++++ .../functional/classification/jaccard.py | 11 ++++++++++ .../classification/matthews_corrcoef.py | 12 +++++++++++ 6 files changed, 90 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 1c9cff4c0ab..2e3bd45e9e0 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -28,6 +28,9 @@ AveragePrecision, BinaryConfusionMatrix, BinaryStatScores, + BinaryCohenKappa, + BinaryJaccardIndex, + BinaryMatthewsCorrCoef, BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision, @@ -47,8 +50,14 @@ MatthewsCorrCoef, MulticlassConfusionMatrix, MulticlassStatScores, + MulticlassMatthewsCorrCoef, + MulticlassCohenKappa, + MulticlassJaccardIndex, MultilabelConfusionMatrix, MultilabelStatScores, + MultilabelCohenKappa, + MultilabelJaccardIndex, + MultilabelMatthewsCorrCoef, Precision, PrecisionRecallCurve, Recall, @@ -131,6 +140,9 @@ "CharErrorRate", "CHRFScore", "CohenKappa", + "BinaryCohenKappa", + "MulticlassCohenKappa", + "MultilabelCohenKappa", "ConfusionMatrix", "BinaryConfusionMatrix", "MulticlassConfusionMatrix", @@ -147,11 +159,17 @@ "HammingDistance", "HingeLoss", "JaccardIndex", + "BinaryJaccardIndex", + "MulticlassJaccardIndex", + "MultilabelJaccardIndex", "KLDivergence", "LabelRankingAveragePrecision", "LabelRankingLoss", "MatchErrorRate", "MatthewsCorrCoef", + "BinaryMatthewsCorrCoef", + "MulticlassMatthewsCorrCoef", + "MultilabelMatthewsCorrCoef", "MaxMetric", "MeanAbsoluteError", "MeanAbsolutePercentageError", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 2575b9f9ba2..8bc740d1a3b 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -19,7 +19,12 @@ from 
torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 -from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 +from torchmetrics.classification.cohen_kappa import ( # noqa: F401 + BinaryCohenKappa, + CohenKappa, + MulticlassCohenKappa, + MultilabelCohenKappa, +) from torchmetrics.classification.confusion_matrix import ( # noqa: F401 BinaryConfusionMatrix, ConfusionMatrix, @@ -30,9 +35,19 @@ from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 from torchmetrics.classification.hinge import HingeLoss # noqa: F401 -from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401 +from torchmetrics.classification.jaccard import ( # noqa: F401 + BinaryJaccardIndex, + JaccardIndex, + MulticlassJaccardIndex, + MultilabelJaccardIndex, +) from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401 -from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef # noqa: F401 +from torchmetrics.classification.matthews_corrcoef import ( # noqa: F401 + BinaryMatthewsCorrCoef, + MatthewsCorrCoef, + MulticlassMatthewsCorrCoef, + MultilabelMatthewsCorrCoef, +) from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 from torchmetrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from torchmetrics.classification.ranking import ( # noqa: F401 diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 7737c429833..a52178e9479 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -19,7 +19,12 @@ from torchmetrics.functional.classification.auroc import auroc from torchmetrics.functional.classification.average_precision import average_precision from torchmetrics.functional.classification.calibration_error import calibration_error -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa +from torchmetrics.functional.classification.cohen_kappa import ( + binary_cohen_kappa, + cohen_kappa, + multiclass_cohen_kappa, + multilabel_cohen_kappa, +) from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, confusion_matrix, @@ -30,9 +35,19 @@ from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score from torchmetrics.functional.classification.hamming import hamming_distance from torchmetrics.functional.classification.hinge import hinge_loss -from torchmetrics.functional.classification.jaccard import jaccard_index +from torchmetrics.functional.classification.jaccard import ( + binary_jaccard_index, + jaccard_index, + multiclass_jaccard_index, + multilabel_jaccard_index, +) from torchmetrics.functional.classification.kl_divergence import kl_divergence -from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef +from torchmetrics.functional.classification.matthews_corrcoef import ( + binary_matthews_corrcoef, + matthews_corrcoef, + multiclass_matthews_corrcoef, + multilabel_matthews_corrcoef +) from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall from torchmetrics.functional.classification.precision_recall_curve import 
precision_recall_curve from torchmetrics.functional.classification.ranking import ( diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 01623e08fab..d983a91a4b4 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -18,6 +18,19 @@ from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update + +def binary_cohen_kappa(): + pass + +def multiclass_cohen_kappa(): + pass + +def multilabel_cohen_kappa(): + pass + + +# -------------------------- Old stuff -------------------------- + _cohen_kappa_update = _confusion_matrix_update diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 4f9b9c2400e..22557bf2429 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -18,6 +18,17 @@ from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update +def binary_jaccard_index(): + pass + +def multiclass_jaccard_index(): + pass + +def multilabel_jaccard_index(): + pass + + +# -------------------------- Old stuff -------------------------- def _jaccard_from_confmat( confmat: Tensor, diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index d7956ffab6d..fa30fedfbef 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -16,6 +16,18 @@ from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update + +def binary_matthews_corrcoef(): + pass + +def multiclass_matthews_corrcoef(): + pass + +def multilabel_matthews_corrcoef(): + pass + +# -------------------------- Old stuff -------------------------- + _matthews_corrcoef_update = _confusion_matrix_update From 141eb02a985c96e5e40f4a5bdc11ae886442c52a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 13 Jul 2022 11:54:56 +0200 Subject: [PATCH 58/74] fix --- src/torchmetrics/__init__.py | 14 +++++++------- src/torchmetrics/functional/__init__.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 2e3bd45e9e0..e9c7687614e 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -26,11 +26,11 @@ ROC, Accuracy, AveragePrecision, - BinaryConfusionMatrix, - BinaryStatScores, BinaryCohenKappa, + BinaryConfusionMatrix, BinaryJaccardIndex, BinaryMatthewsCorrCoef, + BinaryStatScores, BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision, @@ -48,16 +48,16 @@ LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, - MulticlassConfusionMatrix, - MulticlassStatScores, - MulticlassMatthewsCorrCoef, MulticlassCohenKappa, + MulticlassConfusionMatrix, MulticlassJaccardIndex, - MultilabelConfusionMatrix, - MultilabelStatScores, + MulticlassMatthewsCorrCoef, + MulticlassStatScores, MultilabelCohenKappa, + MultilabelConfusionMatrix, MultilabelJaccardIndex, MultilabelMatthewsCorrCoef, + MultilabelStatScores, Precision, PrecisionRecallCurve, Recall, diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index ae576ebb5fb..ddc278c281d 100644 --- a/src/torchmetrics/functional/__init__.py +++ 
b/src/torchmetrics/functional/__init__.py @@ -46,7 +46,7 @@ binary_matthews_corrcoef, matthews_corrcoef, multiclass_matthews_corrcoef, - multilabel_matthews_corrcoef + multilabel_matthews_corrcoef, ) from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall from torchmetrics.functional.classification.precision_recall_curve import precision_recall_curve From 00c18e6fa3bfe737b124a7c4c59b20fa6a6aa762 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Jul 2022 09:58:23 +0000 Subject: [PATCH 59/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- speed.py | 8 +++++--- src/torchmetrics/functional/classification/cohen_kappa.py | 2 ++ src/torchmetrics/functional/classification/jaccard.py | 4 ++++ .../functional/classification/matthews_corrcoef.py | 3 +++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/speed.py b/speed.py index 0a30f2e3028..d8ab77804fd 100644 --- a/speed.py +++ b/speed.py @@ -1,8 +1,10 @@ -from torchmetrics.functional import confusion_matrix, multiclass_confusion_matrix, stat_scores, multiclass_stat_scores -from time import perf_counter -import torch from functools import partial from statistics import mean, stdev +from time import perf_counter + +import torch + +from torchmetrics.functional import confusion_matrix, multiclass_confusion_matrix, multiclass_stat_scores, stat_scores OUTER_REPS = 5 INNER_REPS = 1000 diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index d983a91a4b4..bcc0f838457 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -22,9 +22,11 @@ def binary_cohen_kappa(): pass + def multiclass_cohen_kappa(): pass + def multilabel_cohen_kappa(): pass diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 22557bf2429..2d89088522b 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -18,18 +18,22 @@ from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update + def binary_jaccard_index(): pass + def multiclass_jaccard_index(): pass + def multilabel_jaccard_index(): pass # -------------------------- Old stuff -------------------------- + def _jaccard_from_confmat( confmat: Tensor, num_classes: int, diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index fa30fedfbef..29a83f3b06c 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -20,12 +20,15 @@ def binary_matthews_corrcoef(): pass + def multiclass_matthews_corrcoef(): pass + def multilabel_matthews_corrcoef(): pass + # -------------------------- Old stuff -------------------------- _matthews_corrcoef_update = _confusion_matrix_update From d04d9851f97bcdfc3df09a036a1bd663b4b63acd Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 13 Jul 2022 13:00:33 +0200 Subject: [PATCH 60/74] fix mistakes --- src/torchmetrics/classification/confusion_matrix.py | 3 +++ src/torchmetrics/classification/stat_scores.py | 3 +++ .../classification/test_confusion_matrix.py | 12 ++++++------ 
tests/unittests/classification/test_stat_scores.py | 12 ++++++------ 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index 8ee276723c7..aadb174305e 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -64,6 +64,7 @@ class BinaryConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torchmetrics import BinaryConfusionMatrix @@ -151,6 +152,7 @@ class MulticlassConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (pred is integer tensor): >>> from torchmetrics import MulticlassConfusionMatrix @@ -246,6 +248,7 @@ class MultilabelConfusionMatrix(Metric): - ``'all'``: normalization over the whole matrix validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torchmetrics import MultilabelConfusionMatrix diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index 50bb6ed0837..f82920437c6 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -106,6 +106,7 @@ class BinaryStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torchmetrics.functional import binary_stat_scores @@ -222,6 +223,7 @@ class MulticlassStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Example (preds is int tensor): >>> from torchmetrics.functional import multiclass_stat_scores @@ -364,6 +366,7 @@ class MultilabelStatScores(_AbstractStatScores): Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
Example (preds is int tensor): >>> from torchmetrics.functional import multilabel_stat_scores diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index b52ba3a216f..87ae1c66c95 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -102,7 +102,7 @@ def test_binary_confusion_matrix_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_confusion_matrix_half_cpu(self, input, dtype): + def test_binary_confusion_matrix_dtype_cpu(self, input, dtype): preds, target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -119,7 +119,7 @@ def test_binary_confusion_matrix_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_confusion_matrix_half_gpu(self, input, dtype): + def test_binary_confusion_matrix_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, @@ -197,7 +197,7 @@ def test_multiclass_confusion_matrix_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multiclass_confusion_matrix_half_cpu(self, input, dtype): + def test_multiclass_confusion_matrix_dtype_cpu(self, input, dtype): preds, target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -212,7 +212,7 @@ def test_multiclass_confusion_matrix_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multiclass_confusion_matrix_half_gpu(self, input, dtype): + def test_multiclass_confusion_matrix_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, @@ -295,7 +295,7 @@ def test_multilabel_confusion_matrix_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_confusion_matrix_half_cpu(self, input, dtype): + def test_multilabel_confusion_matrix_dtype_cpu(self, input, dtype): preds, target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -312,7 +312,7 @@ def test_multilabel_confusion_matrix_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_confusion_matrix_half_gpu(self, input, dtype): + def test_multilabel_confusion_matrix_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index d61284834a7..8f30f408fc1 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -122,7 +122,7 @@ def test_binary_stat_scores_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_stat_scores_half_cpu(self, input, dtype): + def test_binary_stat_scores_dtype_cpu(self, input, dtype): preds, 
target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -139,7 +139,7 @@ def test_binary_stat_scores_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_binary_stat_scores_half_gpu(self, input, dtype): + def test_binary_stat_scores_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, @@ -283,7 +283,7 @@ def test_multiclass_stat_scores_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multiclass_stat_scores_half_cpu(self, input, dtype): + def test_multiclass_stat_scores_dtype_cpu(self, input, dtype): preds, target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -300,7 +300,7 @@ def test_multiclass_stat_scores_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multiclass_stat_scores_half_gpu(self, input, dtype): + def test_multiclass_stat_scores_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, @@ -466,7 +466,7 @@ def test_multilabel_stat_scores_differentiability(self, input): ) @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_stat_scores_half_cpu(self, input, dtype): + def test_multilabel_stat_scores_dtype_cpu(self, input, dtype): preds, target = input if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: pytest.xfail(reason="half support of core ops not support before pytorch v1.6") @@ -483,7 +483,7 @@ def test_multilabel_stat_scores_half_cpu(self, input, dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") @pytest.mark.parametrize("dtype", [torch.half, torch.double]) - def test_multilabel_stat_scores_half_gpu(self, input, dtype): + def test_multilabel_stat_scores_dtype_gpu(self, input, dtype): preds, target = input self.run_precision_test_gpu( preds=preds, From c29035366fa347d452e0b93e1f33b6ff0bc3173d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 13 Jul 2022 13:01:21 +0200 Subject: [PATCH 61/74] working cohen kappa --- docs/source/classification/cohen_kappa.rst | 14 - src/torchmetrics/__init__.py | 2 - src/torchmetrics/classification/__init__.py | 11 +- .../classification/cohen_kappa.py | 169 ++++++++-- src/torchmetrics/classification/jaccard.py | 4 - src/torchmetrics/functional/__init__.py | 1 - .../functional/classification/cohen_kappa.py | 109 ++++++- .../classification/test_cohen_kappa.py | 290 ++++++++++++------ 8 files changed, 446 insertions(+), 154 deletions(-) diff --git a/docs/source/classification/cohen_kappa.rst b/docs/source/classification/cohen_kappa.rst index 4a5abbd8094..e5617796659 100644 --- a/docs/source/classification/cohen_kappa.rst +++ b/docs/source/classification/cohen_kappa.rst @@ -32,14 +32,6 @@ MulticlassCohenKappa :noindex: :exclude-members: update, compute -MultilabelCohenKappa -^^^^^^^^^^^^^^^^^^^^ - -.. autoclass:: torchmetrics.MultilabelCohenKappa - :noindex: - :exclude-members: update, compute - - Functional Interface ____________________ @@ -60,9 +52,3 @@ multiclass_cohen_kappa .. 
autofunction:: torchmetrics.functional.multiclass_cohen_kappa :noindex: - -multilabel_cohen_kappa -^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: torchmetrics.functional.multilabel_cohen_kappa - :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index e9c7687614e..16c1f9e788a 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -53,7 +53,6 @@ MulticlassJaccardIndex, MulticlassMatthewsCorrCoef, MulticlassStatScores, - MultilabelCohenKappa, MultilabelConfusionMatrix, MultilabelJaccardIndex, MultilabelMatthewsCorrCoef, @@ -142,7 +141,6 @@ "CohenKappa", "BinaryCohenKappa", "MulticlassCohenKappa", - "MultilabelCohenKappa", "ConfusionMatrix", "BinaryConfusionMatrix", "MulticlassConfusionMatrix", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 8bc740d1a3b..ca4209909bb 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -19,18 +19,17 @@ from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 -from torchmetrics.classification.cohen_kappa import ( # noqa: F401 - BinaryCohenKappa, - CohenKappa, - MulticlassCohenKappa, - MultilabelCohenKappa, -) from torchmetrics.classification.confusion_matrix import ( # noqa: F401 BinaryConfusionMatrix, ConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix, ) +from torchmetrics.classification.cohen_kappa import ( # noqa: F401 + BinaryCohenKappa, + CohenKappa, + MulticlassCohenKappa, +) from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 57f433fb038..b8eb6c79e8b 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -15,61 +15,174 @@ import torch from torch import Tensor - -from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix -from torchmetrics.functional.classification.cohen_kappa import _cohen_kappa_compute, _cohen_kappa_update +from typing_extensions import Literal + +from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix +from torchmetrics.functional.classification.cohen_kappa import ( + _cohen_kappa_compute, + _cohen_kappa_update, + _cohen_kappa_reduce, + _binary_cohen_kappa_arg_validation, + _multiclass_cohen_kappa_arg_validation, +) from torchmetrics.metric import Metric class BinaryCohenKappa(BinaryConfusionMatrix): - """""" + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for binary + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. 
If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + weights: Weighting type to calculate the score. Choose from: + + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (preds is int tensor): + >>> from torchmetrics import BinaryCohenKappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryCohenKappa() + >>> metric(preds, target) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics import BinaryCohenKappa + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryCohenKappa() + >>> metric(preds, target) + tensor(0.5000) + + """ is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False - def __init__(self, **kwargs): - pass - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - pass + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(threshold, ignore_index, normalize=None, validate_args=False, **kwargs) + if validate_args: + _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights) + self.weights = weights + self.validate_args = validate_args def compute(self) -> Tensor: - pass + return _cohen_kappa_reduce(self.confmat, self.weights) class MulticlassCohenKappa(MulticlassConfusionMatrix): - """""" + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass + tasks. It is defined as is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False - def __init__(self, **kwargs): - pass + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - pass + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. - def compute(self) -> Tensor: - pass + Accepts the following input tensors: + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + Additional dimension ``...`` will be flattened into the batch dimension.
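[Editor's note] The corrected doctest values above can be checked against scikit-learn, which the test suite added later in this patch also uses as its reference implementation:

from sklearn.metrics import cohen_kappa_score

# target = [1, 1, 0, 0], preds = [0, 1, 0, 0]: observed agreement p_o = 3/4,
# chance agreement p_e = (2 * 3 + 2 * 1) / 16 = 1/2, so kappa = (0.75 - 0.5) / (1 - 0.5) = 0.5
print(cohen_kappa_score([1, 1, 0, 0], [0, 1, 0, 0]))  # 0.5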
-class MultilabelCohenKappa(MultilabelConfusionMatrix): - """""" + Args: + num_classes: Integer specifying the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + weights: Weighting type to calculate the score. Choose from: - is_differentiable: bool = False - higher_is_better: bool = True - full_state_update: bool = False + - ``None`` or ``'none'``: no weighting + - ``'linear'``: linear weighting + - ``'quadratic'``: quadratic weighting + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example (pred is integer tensor): + >>> from torchmetrics import MulticlassCohenKappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> metric(preds, target) + tensor(0.6364) + + Example (pred is float tensor): + >>> from torchmetrics import MulticlassCohenKappa + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassCohenKappa(num_classes=3) + >>> metric(preds, target) + tensor(0.6364) + + """ is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False - def __init__(self, **kwargs): - pass - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - pass - - def compute(self) -> Tensor: - pass + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(num_classes, ignore_index, normalize=None, validate_args=False, **kwargs) + if validate_args: + _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights) + self.weights = weights + self.validate_args = validate_args def compute(self) -> Tensor: - pass + return _cohen_kappa_reduce(self.confmat, self.weights) # -------------------------- Old stuff -------------------------- diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 023c09ad6d6..231f2063629 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -19,11 +19,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.jaccard import ( - _binary_jaccard_index_compute, - _binary_jaccard_index_validate_args, _jaccard_from_confmat, - _multiclass_jaccard_index_compute, - _multilabel_jaccard_index_compute, ) diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index ddc278c281d..ab107fd144f 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -23,7 +23,6 @@ binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa, - multilabel_cohen_kappa, ) from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 01623e08fab..090a99b20b1 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++
b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -15,18 +15,113 @@ import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _confusion_matrix_compute, + _confusion_matrix_update, + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_format, + _binary_confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_update +) + + +def _cohen_kappa_reduce(confmat: Tensor, weights: Optional[Literal["linear", "quadratic", "none"]] = None) -> Tensor: + """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the cohen kappa score.""" + confmat = confmat.float() if not confmat.is_floating_point() else confmat + n_classes = confmat.shape[0] + sum0 = confmat.sum(dim=0, keepdim=True) + sum1 = confmat.sum(dim=1, keepdim=True) + expected = sum1 @ sum0 / sum0.sum() # outer product -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update + if weights is None or weights == "none": + w_mat = torch.ones_like(confmat).flatten() + w_mat[:: n_classes + 1] = 0 + w_mat = w_mat.reshape(n_classes, n_classes) + elif weights in ("linear", "quadratic"): + w_mat = torch.zeros_like(confmat) + w_mat += torch.arange(n_classes, dtype=w_mat.dtype, device=w_mat.device) + if weights == "linear": + w_mat = torch.abs(w_mat - w_mat.T) + else: + w_mat = torch.pow(w_mat - w_mat.T, 2.0) + else: + raise ValueError( + f"Received {weights} for argument ``weights`` but should be either" " None, 'linear' or 'quadratic'" + ) + k = torch.sum(w_mat * confmat) / torch.sum(w_mat * expected) + return 1 - k + + +def _binary_cohen_kappa_arg_validation( + threshold: float = 0.5, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``weights`` has to be "linear" | "quadratic" | "none" | None + """ + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize=None) + allowed_weights = ("linear", "quadratic", "none", None) + if weights not in allowed_weights: + raise ValueError(f"Expected argument `weight` to be one of {allowed_weights}, but got {weights}.") + + +def binary_cohen_kappa( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _cohen_kappa_reduce(confmat, weights) -def binary_cohen_kappa(): - pass +def _multiclass_cohen_kappa_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + weights: Optional[Literal["linear", "quadratic", "none"]] = None, +) -> None: + """Validate non tensor input. 
+
+    - ``num_classes`` has to be an int larger than 1
+    - ``ignore_index`` has to be None or int
+    - ``weights`` has to be "linear" | "quadratic" | "none" | None
+    """
+    _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize=None)
+    allowed_weights = ("linear", "quadratic", "none", None)
+    if weights not in allowed_weights:
+        raise ValueError(f"Expected argument `weights` to be one of {allowed_weights}, but got {weights}.")

-def multiclass_cohen_kappa():
-    pass

-def multilabel_cohen_kappa():
-    pass
+def multiclass_cohen_kappa(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+    weights: Optional[Literal["linear", "quadratic", "none"]] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    if validate_args:
+        _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights)
+        _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
+    confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
+    return _cohen_kappa_reduce(confmat, weights)


# -------------------------- Old stuff --------------------------
diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py
index d0a2b1c9c9c..c68cd3ef9bc 100644
--- a/tests/unittests/classification/test_cohen_kappa.py
+++ b/tests/unittests/classification/test_cohen_kappa.py
@@ -4,130 +4,236 @@
 import pytest
 import torch
 from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa
-
-from torchmetrics.classification.cohen_kappa import CohenKappa
-from torchmetrics.functional.classification.cohen_kappa import cohen_kappa
-from unittests.classification.inputs import _input_binary, _input_binary_prob
-from unittests.classification.inputs import _input_multiclass as _input_mcls
-from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
-from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc
-from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
-from unittests.classification.inputs import _input_multilabel as _input_mlb
-from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
+from scipy.special import expit as sigmoid
+from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6
+from torchmetrics.classification.cohen_kappa import CohenKappa, BinaryCohenKappa, MulticlassCohenKappa
+from torchmetrics.functional.classification.cohen_kappa import cohen_kappa, binary_cohen_kappa, multiclass_cohen_kappa
+from unittests.classification.inputs import _binary_cases, _multiclass_cases
 from unittests.helpers import seed_all
-from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
+from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index

 seed_all(42)


-def _sk_cohen_kappa_binary_prob(preds, target, weights=None):
-    sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
-    sk_target = target.view(-1).numpy()
+def _sk_cohen_kappa_binary(preds, target, weights=None, ignore_index=None):
+    preds = preds.view(-1).numpy()
+    target = target.view(-1).numpy()
+    if np.issubdtype(preds.dtype, np.floating):
+        if not ((0 < preds) & (preds < 1)).all():
+            preds = sigmoid(preds)
+        preds = (preds >= THRESHOLD).astype(np.uint8)
+    if ignore_index is not None:
+        idx = target == ignore_index
+        target = target[~idx]
+        preds = 
preds[~idx] + return sk_cohen_kappa(y1=target, y2=preds, weights=weights) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryConfusionMatrix(MetricTester): + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_cohen_kappa(self, input, ddp, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryCohenKappa, + sk_metric=partial(_sk_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "weights": weights, + "ignore_index": ignore_index, + }, + ) - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_confusion_matrix_functional(self, input, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_cohen_kappa, + sk_metric=partial(_sk_cohen_kappa_binary, weights=weights, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "weights": weights, + "ignore_index": ignore_index, + }, + ) + def test_binary_cohen_kappa_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + ) -def _sk_cohen_kappa_binary(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_cohen_kappa_dtypes_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_dtypes_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryCohenKappa, + metric_functional=binary_cohen_kappa, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_cohen_kappa_multilabel_prob(preds, target, weights=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() +def _sk_cohen_kappa_multiclass(preds, target, weights, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + if 
ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_cohen_kappa(y1=target, y2=preds, weights=weights) -def _sk_cohen_kappa_multilabel(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +# -------------------------- Old stuff -------------------------- - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +# def _sk_cohen_kappa_binary_prob(preds, target, weights=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) -def _sk_cohen_kappa_multiclass_prob(preds, target, weights=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +# def _sk_cohen_kappa_binary(preds, target, weights=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) -def _sk_cohen_kappa_multiclass(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +# def _sk_cohen_kappa_multilabel_prob(preds, target, weights=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) -def _sk_cohen_kappa_multidim_multiclass_prob(preds, target, weights=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +# def _sk_cohen_kappa_multilabel(preds, target, weights=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) -def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) +# def _sk_cohen_kappa_multiclass_prob(preds, target, weights=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() +# sk_target = target.view(-1).numpy() +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) -@pytest.mark.parametrize("weights", ["linear", "quadratic", None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES), - ], -) -class TestCohenKappa(MetricTester): - atol = 1e-5 - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def 
test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=CohenKappa, - sk_metric=partial(sk_metric, weights=weights), - dist_sync_on_step=dist_sync_on_step, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, - ) +# def _sk_cohen_kappa_multiclass(preds, target, weights=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() - def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes): - self.run_functional_metric_test( - preds, - target, - metric_functional=cohen_kappa, - sk_metric=partial(sk_metric, weights=weights), - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, - ) +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) - def test_cohen_kappa_differentiability(self, preds, target, sk_metric, weights, num_classes): - self.run_differentiability_test( - preds=preds, - target=target, - metric_module=CohenKappa, - metric_functional=cohen_kappa, - metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, - ) + +# def _sk_cohen_kappa_multidim_multiclass_prob(preds, target, weights=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + + +# def _sk_cohen_kappa_multidim_multiclass(preds, target, weights=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_cohen_kappa(y1=sk_target, y2=sk_preds, weights=weights) + + +# @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_cohen_kappa_binary_prob, 2), +# (_input_binary.preds, _input_binary.target, _sk_cohen_kappa_binary, 2), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cohen_kappa_multilabel_prob, 2), +# (_input_mlb.preds, _input_mlb.target, _sk_cohen_kappa_multilabel, 2), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cohen_kappa_multiclass_prob, NUM_CLASSES), +# (_input_mcls.preds, _input_mcls.target, _sk_cohen_kappa_multiclass, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cohen_kappa_multidim_multiclass_prob, NUM_CLASSES), +# (_input_mdmc.preds, _input_mdmc.target, _sk_cohen_kappa_multidim_multiclass, NUM_CLASSES), +# ], +# ) +# class TestCohenKappa(MetricTester): +# atol = 1e-5 + +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_cohen_kappa(self, weights, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=CohenKappa, +# sk_metric=partial(sk_metric, weights=weights), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, +# ) + +# def test_cohen_kappa_functional(self, weights, preds, target, sk_metric, num_classes): +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=cohen_kappa, +# sk_metric=partial(sk_metric, weights=weights), +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, +# ) + +# def test_cohen_kappa_differentiability(self, preds, target, sk_metric, 
weights, num_classes): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=CohenKappa, +# metric_functional=cohen_kappa, +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "weights": weights}, +# ) -def test_warning_on_wrong_weights(tmpdir): - preds = torch.randint(3, size=(20,)) - target = torch.randint(3, size=(20,)) +# def test_warning_on_wrong_weights(tmpdir): +# preds = torch.randint(3, size=(20,)) +# target = torch.randint(3, size=(20,)) - with pytest.raises(ValueError, match=".* ``weights`` but should be either None, 'linear' or 'quadratic'"): - cohen_kappa(preds, target, num_classes=3, weights="unknown_arg") +# with pytest.raises(ValueError, match=".* ``weights`` but should be either None, 'linear' or 'quadratic'"): +# cohen_kappa(preds, target, num_classes=3, weights="unknown_arg") From 97acff4d897200ddc8eaf8befa55925fce771a6a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Jul 2022 11:03:37 +0000 Subject: [PATCH 62/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- speed.py | 8 +++++--- src/torchmetrics/classification/__init__.py | 6 +----- src/torchmetrics/classification/cohen_kappa.py | 6 +++--- src/torchmetrics/classification/jaccard.py | 4 +--- src/torchmetrics/functional/__init__.py | 6 +----- .../functional/classification/cohen_kappa.py | 10 +++++----- tests/unittests/classification/test_cohen_kappa.py | 7 ++++--- 7 files changed, 20 insertions(+), 27 deletions(-) diff --git a/speed.py b/speed.py index 0a30f2e3028..d8ab77804fd 100644 --- a/speed.py +++ b/speed.py @@ -1,8 +1,10 @@ -from torchmetrics.functional import confusion_matrix, multiclass_confusion_matrix, stat_scores, multiclass_stat_scores -from time import perf_counter -import torch from functools import partial from statistics import mean, stdev +from time import perf_counter + +import torch + +from torchmetrics.functional import confusion_matrix, multiclass_confusion_matrix, multiclass_stat_scores, stat_scores OUTER_REPS = 5 INNER_REPS = 1000 diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index ca4209909bb..9ec9648a58b 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -19,17 +19,13 @@ from torchmetrics.classification.binned_precision_recall import BinnedPrecisionRecallCurve # noqa: F401 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 +from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa # noqa: F401 from torchmetrics.classification.confusion_matrix import ( # noqa: F401 BinaryConfusionMatrix, ConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix, ) -from torchmetrics.classification.cohen_kappa import ( # noqa: F401 - BinaryCohenKappa, - CohenKappa, - MulticlassCohenKappa, -) from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index b8eb6c79e8b..7e5dd86d7a0 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ 
b/src/torchmetrics/classification/cohen_kappa.py @@ -19,10 +19,10 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( - _cohen_kappa_compute, - _cohen_kappa_update, - _cohen_kappa_reduce, _binary_cohen_kappa_arg_validation, + _cohen_kappa_compute, + _cohen_kappa_reduce, + _cohen_kappa_update, _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 231f2063629..8c08582f8fd 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -18,9 +18,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.jaccard import ( - _jaccard_from_confmat, -) +from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat class BinaryJaccardIndex(BinaryConfusionMatrix): diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index ab107fd144f..294c3b97ca7 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -19,11 +19,7 @@ from torchmetrics.functional.classification.auroc import auroc from torchmetrics.functional.classification.average_precision import average_precision from torchmetrics.functional.classification.calibration_error import calibration_error -from torchmetrics.functional.classification.cohen_kappa import ( - binary_cohen_kappa, - cohen_kappa, - multiclass_cohen_kappa, -) +from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, confusion_matrix, diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 090a99b20b1..bebbbfe1400 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -18,16 +18,16 @@ from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import ( - _confusion_matrix_compute, - _confusion_matrix_update, _binary_confusion_matrix_arg_validation, - _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, - _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_format, - _multiclass_confusion_matrix_update + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, ) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index c68cd3ef9bc..38bbd961d2c 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -3,11 +3,12 @@ import numpy as np import pytest import torch -from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa from scipy.special import expit as sigmoid +from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa + +from torchmetrics.classification.cohen_kappa import 
BinaryCohenKappa, CohenKappa, MulticlassCohenKappa +from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from torchmetrics.classification.cohen_kappa import CohenKappa, BinaryCohenKappa, MulticlassCohenKappa -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa, binary_cohen_kappa, multiclass_cohen_kappa from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index From 4d3f1d639ef05fe7da6cf2f1c8ccc3c1ee10b302 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 13 Jul 2022 13:04:55 +0200 Subject: [PATCH 63/74] Delete speed.py --- speed.py | 38 -------------------------------------- 1 file changed, 38 deletions(-) delete mode 100644 speed.py diff --git a/speed.py b/speed.py deleted file mode 100644 index d8ab77804fd..00000000000 --- a/speed.py +++ /dev/null @@ -1,38 +0,0 @@ -from functools import partial -from statistics import mean, stdev -from time import perf_counter - -import torch - -from torchmetrics.functional import confusion_matrix, multiclass_confusion_matrix, multiclass_stat_scores, stat_scores - -OUTER_REPS = 5 -INNER_REPS = 1000 - -preds = torch.randn(100, 10).softmax(dim=-1) -target = torch.randint(10, (100,)) - - -def time(metric_fn, name, base=None): - timings = [] - for _ in range(OUTER_REPS): - start = perf_counter() - for _ in range(INNER_REPS): - metric_fn(preds, target) - end = perf_counter() - timings.append(end - start) - extra = f", speedup: {base / mean(timings)}" if base is not None else "" - print(f"{name.ljust(15)}: {mean(timings):0.3E} +- {stdev(timings):0.3E}{extra}") - return mean(timings) - - -print(f"\nExperiments running {INNER_REPS} calculations, repeting {OUTER_REPS} times:") -print("\nMulticlass Confusion matrix") -base = time(partial(confusion_matrix, num_classes=10), name="Old") -time(partial(multiclass_confusion_matrix, num_classes=10), name="New with IV", base=base) -time(partial(multiclass_confusion_matrix, num_classes=10, validate_args=False), name="New without IV", base=base) - -print("\nMulticlass Stat Scores") -base = time(partial(stat_scores, num_classes=10), name="Old") -time(partial(multiclass_stat_scores, num_classes=10), name="New with IV", base=base) -time(partial(multiclass_stat_scores, num_classes=10, validate_args=False), name="New without IV", base=base) From 57e32c450a4bccbfabdd6668bd137bcf8010eef7 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 13 Jul 2022 15:01:08 +0200 Subject: [PATCH 64/74] cohen kappa done --- .../classification/cohen_kappa.py | 16 +-- src/torchmetrics/functional/__init__.py | 8 ++ .../functional/classification/cohen_kappa.py | 115 +++++++++++++++++- .../classification/test_cohen_kappa.py | 104 +++++++++++++++- 4 files changed, 224 insertions(+), 19 deletions(-) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index b8eb6c79e8b..e68588179cf 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -19,7 +19,7 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( - _cohen_kappa_compute, + _cohen_kappa_compute, _cohen_kappa_update, _cohen_kappa_reduce, 
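For reference, the benchmarking pattern from the deleted speed.py reduces to the sketch below; the metric, tensor shapes, repetition count, and the helper name `time_metric` are illustrative, not a claim about measured speedups.

from functools import partial
from time import perf_counter

import torch

from torchmetrics.functional import multiclass_confusion_matrix

preds = torch.randn(100, 10).softmax(dim=-1)
target = torch.randint(10, (100,))

def time_metric(fn, reps=1000):
    # wall-clock time for `reps` repeated metric calls on fixed inputs
    start = perf_counter()
    for _ in range(reps):
        fn(preds, target)
    return perf_counter() - start

with_checks = time_metric(partial(multiclass_confusion_matrix, num_classes=10))
no_checks = time_metric(partial(multiclass_confusion_matrix, num_classes=10, validate_args=False))
print(f"speedup from validate_args=False: {with_checks / no_checks:.2f}x")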
_binary_cohen_kappa_arg_validation, @@ -68,8 +68,7 @@ class labels. >>> preds = torch.tensor([0, 1, 0, 0]) >>> metric = BinaryCohenKappa() >>> metric(preds, target) - tensor([[2, 0], - [1, 1]]) + tensor(0.5000) Example (preds is float tensor): >>> from torchmetrics import BinaryCohenKappa @@ -77,8 +76,7 @@ class labels. >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) >>> metric = BinaryCohenKappa() >>> metric(preds, target) - tensor([[2, 0], - [1, 1]]) + tensor(0.5000) """ is_differentiable: bool = False @@ -143,9 +141,7 @@ class labels. >>> preds = torch.tensor([2, 1, 0, 1]) >>> metric = MulticlassCohenKappa(num_classes=3) >>> metric(preds, target) - tensor([[1, 1, 0], - [0, 1, 0], - [0, 0, 1]]) + tensor(0.6364) Example (pred is float tensor): >>> from torchmetrics import MulticlassCohenKappa @@ -158,9 +154,7 @@ class labels. ... ]) >>> metric = MulticlassCohenKappa(num_classes=3) >>> metric(preds, target) - tensor([[1, 1, 0], - [0, 1, 0], - [0, 0, 1]]) + tensor(0.6364) """ is_differentiable: bool = False diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index ab107fd144f..f571a3de833 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -202,4 +202,12 @@ "binary_stat_scores", "multiclass_stat_scores", "multilabel_stat_scores", + "binary_cohen_kappa", + "multiclass_cohen_kappa", + "binary_jaccard_index", + "multiclass_jaccard_index", + "multilabel_jaccard_index", + "binary_matthews_corrcoef", + "multiclass_matthews_corrcoef", + "multilabel_matthews_corrcoef", ] diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 090a99b20b1..10ff193fdf4 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -18,16 +18,16 @@ from typing_extensions import Literal from torchmetrics.functional.classification.confusion_matrix import ( - _confusion_matrix_compute, - _confusion_matrix_update, _binary_confusion_matrix_arg_validation, - _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, - _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_format, - _multiclass_confusion_matrix_update + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, ) @@ -83,6 +83,56 @@ def binary_cohen_kappa( weights: Optional[Literal["linear", "quadratic", "none"]] = None, validate_args: bool = True, ) -> Tensor: + r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for binary + tasks. It is defined as + + .. math:: + \kappa = (p_o - p_e) / (1 - p_e) + + where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is + the expected agreement when both annotators assign labels randomly. Note that + :math:`p_e` is estimated using a per-annotator empirical prior over the + class labels. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. 
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        weights: Weighting type to calculate the score. Choose from:
+
+            - ``None`` or ``'none'``: no weighting
+            - ``'linear'``: linear weighting
+            - ``'quadratic'``: quadratic weighting
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics.functional import binary_cohen_kappa
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0, 1, 0, 0])
+        >>> binary_cohen_kappa(preds, target)
+        tensor(0.5000)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.functional import binary_cohen_kappa
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
+        >>> binary_cohen_kappa(preds, target)
+        tensor(0.5000)
+
+    """
     if validate_args:
         _binary_cohen_kappa_arg_validation(threshold, ignore_index, weights)
         _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
@@ -116,6 +166,61 @@ def multiclass_cohen_kappa(
     preds: Tensor,
     target: Tensor,
     num_classes: int,
     ignore_index: Optional[int] = None,
     weights: Optional[Literal["linear", "quadratic", "none"]] = None,
     validate_args: bool = True,
 ) -> Tensor:
+    r"""Calculates `Cohen's kappa score`_ that measures inter-annotator agreement for multiclass
+    tasks. It is defined as
+
+    .. math::
+        \kappa = (p_o - p_e) / (1 - p_e)
+
+    where :math:`p_o` is the empirical probability of agreement and :math:`p_e` is
+    the expected agreement when both annotators assign labels randomly. Note that
+    :math:`p_e` is estimated using a per-annotator empirical prior over the
+    class labels.
+
+    Accepts the following input tensors:
+
+    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
+      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
+      an int tensor.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        weights: Weighting type to calculate the score. Choose from:
+
+            - ``None`` or ``'none'``: no weighting
+            - ``'linear'``: linear weighting
+            - ``'quadratic'``: quadratic weighting
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+
+    Example (pred is integer tensor):
+        >>> from torchmetrics.functional import multiclass_cohen_kappa
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> multiclass_cohen_kappa(preds, target, num_classes=3)
+        tensor(0.6364)
+
+    Example (pred is float tensor):
+        >>> from torchmetrics.functional import multiclass_cohen_kappa
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([
+        ...   [0.16, 0.26, 0.58],
+        ...   [0.22, 0.61, 0.17],
+        ...   
[0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_cohen_kappa(preds, target, num_classes=3) + tensor(0.6364) + + """ if validate_args: _multiclass_cohen_kappa_arg_validation(num_classes, ignore_index, weights) _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index c68cd3ef9bc..cf185835deb 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -1,13 +1,27 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from functools import partial import numpy as np import pytest import torch -from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa from scipy.special import expit as sigmoid +from sklearn.metrics import cohen_kappa_score as sk_cohen_kappa + +from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, MulticlassCohenKappa +from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, multiclass_cohen_kappa from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 -from torchmetrics.classification.cohen_kappa import CohenKappa, BinaryCohenKappa, MulticlassCohenKappa -from torchmetrics.functional.classification.cohen_kappa import cohen_kappa, binary_cohen_kappa, multiclass_cohen_kappa from unittests.classification.inputs import _binary_cases, _multiclass_cases from unittests.helpers import seed_all from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index @@ -31,6 +45,8 @@ def _sk_cohen_kappa_binary(preds, target, weights=None, ignore_index=None): @pytest.mark.parametrize("input", _binary_cases) class TestBinaryConfusionMatrix(MetricTester): + atol = 1e-5 + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) @@ -124,6 +140,88 @@ def _sk_cohen_kappa_multiclass(preds, target, weights, ignore_index=None): return sk_cohen_kappa(y1=target, y2=preds, weights=weights) +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassConfusionMatrix(MetricTester): + atol = 1e-5 + + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_cohen_kappa(self, input, ddp, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassCohenKappa, + sk_metric=partial(_sk_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "weights": weights, + "ignore_index": ignore_index, + }, + ) + + @pytest.mark.parametrize("weights", ["linear", "quadratic", None]) 
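How the `weights` parametrization above changes the score is worth seeing once on concrete numbers; a small sketch (inputs purely illustrative) where the single error lands two classes away from the target:

import torch

from torchmetrics.functional import multiclass_cohen_kappa

target = torch.tensor([0, 1, 2, 2])
preds = torch.tensor([0, 1, 2, 0])  # one miss, two classes away from the target
for weights in (None, "linear", "quadratic"):
    print(weights, multiclass_cohen_kappa(preds, target, num_classes=3, weights=weights))
# None: 0.6364, linear: 0.5000, quadratic: 0.3846 -- heavier weighting
# penalizes the distant miss harder, so kappa drops.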
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_confusion_matrix_functional(self, input, weights, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_cohen_kappa, + sk_metric=partial(_sk_cohen_kappa_multiclass, weights=weights, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "weights": weights, + "ignore_index": ignore_index, + }, + ) + + def test_multiclass_cohen_kappa_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_cohen_kappa_dtypes_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_dtypes_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassCohenKappa, + metric_functional=multiclass_cohen_kappa, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + # -------------------------- Old stuff -------------------------- # def _sk_cohen_kappa_binary_prob(preds, target, weights=None): From 5a75458ff2d1ce704cb64076fc158b2e1bb57d8d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 13 Jul 2022 15:54:02 +0200 Subject: [PATCH 65/74] working matthews --- .../classification/cohen_kappa.py | 4 +- .../classification/matthews_corrcoef.py | 187 +++++++- src/torchmetrics/functional/__init__.py | 6 +- .../classification/matthews_corrcoef.py | 213 +++++++- .../classification/test_matthews_corrcoef.py | 454 ++++++++++++++---- 5 files changed, 744 insertions(+), 120 deletions(-) diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index e68588179cf..9a0ba35fd4f 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -19,10 +19,10 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix from torchmetrics.functional.classification.cohen_kappa import ( + _binary_cohen_kappa_arg_validation, _cohen_kappa_compute, - _cohen_kappa_update, _cohen_kappa_reduce, - _binary_cohen_kappa_arg_validation, + _cohen_kappa_update, _multiclass_cohen_kappa_arg_validation, ) from torchmetrics.metric import Metric diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 2198d934fac..dbfb46a7c05 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -11,7 +11,7 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional

 import torch
 from torch import Tensor
@@ -20,59 +20,202 @@
 from torchmetrics.functional.classification.matthews_corrcoef import (
     _matthews_corrcoef_compute,
     _matthews_corrcoef_update,
+    _matthews_corrcoef_reduce
 )
 from torchmetrics.metric import Metric


 class BinaryMatthewsCorrCoef(BinaryConfusionMatrix):
-    """"""
+    r"""Calculates `Matthews correlation coefficient`_ for binary tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
+      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
+      we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics import BinaryMatthewsCorrCoef
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0, 1, 0, 0])
+        >>> metric = BinaryMatthewsCorrCoef()
+        >>> metric(preds, target)
+        tensor(0.5774)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics import BinaryMatthewsCorrCoef
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
+        >>> metric = BinaryMatthewsCorrCoef()
+        >>> metric(preds, target)
+        tensor(0.5774)
+
+    """

     is_differentiable: bool = False
     higher_is_better: bool = True
     full_state_update: bool = False

-    def __init__(self, **kwargs):
-        pass
-
-    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
-        pass
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs)

     def compute(self) -> Tensor:
-        pass
+        return _matthews_corrcoef_reduce(self.confmat)


 class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix):
-    """"""
+    r"""Calculates `Matthews correlation coefficient`_ for multiclass tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
+      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
+      an int tensor. 
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (pred is integer tensor):
+        >>> from torchmetrics import MulticlassMatthewsCorrCoef
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
+        >>> metric(preds, target)
+        tensor(0.7000)
+
+    Example (pred is float tensor):
+        >>> from torchmetrics import MulticlassMatthewsCorrCoef
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([
+        ...   [0.16, 0.26, 0.58],
+        ...   [0.22, 0.61, 0.17],
+        ...   [0.71, 0.09, 0.20],
+        ...   [0.05, 0.82, 0.13],
+        ... ])
+        >>> metric = MulticlassMatthewsCorrCoef(num_classes=3)
+        >>> metric(preds, target)
+        tensor(0.7000)
+
+    """

     is_differentiable: bool = False
     higher_is_better: bool = True
     full_state_update: bool = False

-    def __init__(self, **kwargs):
-        pass
-
-    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
-        pass
+    def __init__(
+        self,
+        num_classes: int,
+        ignore_index: Optional[int] = None,
+        validate_args: bool = True,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(num_classes, ignore_index, normalize=None, validate_args=validate_args, **kwargs)

     def compute(self) -> Tensor:
-        pass
+        return _matthews_corrcoef_reduce(self.confmat)


 class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix):
-    """"""
+    r"""Calculates `Matthews correlation coefficient`_ for multilabel tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
+      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
+      we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, C, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
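Because the classes in this module only accumulate a confusion matrix as state, batched updates through the module API should match a single functional call on the concatenated inputs; a quick sketch with made-up tensors:

import torch

from torchmetrics.classification.matthews_corrcoef import MulticlassMatthewsCorrCoef
from torchmetrics.functional.classification.matthews_corrcoef import multiclass_matthews_corrcoef

preds = torch.tensor([2, 1, 0, 1, 2, 0])
target = torch.tensor([2, 1, 0, 0, 2, 1])

metric = MulticlassMatthewsCorrCoef(num_classes=3)
for p, t in zip(preds.chunk(2), target.chunk(2)):  # feed two "batches"
    metric.update(p, t)

# confusion matrices add across batches, so the reductions agree
assert torch.allclose(metric.compute(), multiclass_matthews_corrcoef(preds, target, num_classes=3))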
+ + Example (preds is int tensor): + >>> from torchmetrics import MultilabelMatthewsCorrCoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics import MultilabelMatthewsCorrCoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelMatthewsCorrCoef(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + + """ is_differentiable: bool = False higher_is_better: bool = True full_state_update: bool = False - def __init__(self, **kwargs): - pass - - def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore - pass + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(num_labels, threshold, ignore_index, normalize=None, validate_args=validate_args, **kwargs) def compute(self) -> Tensor: - pass + return _matthews_corrcoef_reduce(self.confmat) # -------------------------- Old stuff -------------------------- diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index f571a3de833..3d55dd9bcd5 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -19,11 +19,7 @@ from torchmetrics.functional.classification.auroc import auroc from torchmetrics.functional.classification.average_precision import average_precision from torchmetrics.functional.classification.calibration_error import calibration_error -from torchmetrics.functional.classification.cohen_kappa import ( - binary_cohen_kappa, - cohen_kappa, - multiclass_cohen_kappa, -) +from torchmetrics.functional.classification.cohen_kappa import binary_cohen_kappa, cohen_kappa, multiclass_cohen_kappa from torchmetrics.functional.classification.confusion_matrix import ( binary_confusion_matrix, confusion_matrix, diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 29a83f3b06c..4db8a28dcb2 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -11,22 +11,221 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
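One property of the reduction defined below that is easy to sanity-check: for multilabel inputs the per-label confusion matrices are summed into a single 2x2 table, so the score coincides with a binary MCC over the flattened (sample, label) grid. A sketch, reusing the tensors from the docstring examples:

import torch

from torchmetrics.functional.classification.matthews_corrcoef import (
    binary_matthews_corrcoef,
    multilabel_matthews_corrcoef,
)

target = torch.tensor([[0, 1, 0], [1, 0, 1]])
preds = torch.tensor([[0, 0, 1], [1, 0, 1]])

ml = multilabel_matthews_corrcoef(preds, target, num_labels=3)
flat = binary_matthews_corrcoef(preds.flatten(), target.flatten())
assert torch.allclose(ml, flat)  # both tensor(0.3333)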
+from typing import Optional
+
 import torch
 from torch import Tensor

-from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
+from torchmetrics.functional.classification.confusion_matrix import (
+    _confusion_matrix_update,
+    _binary_confusion_matrix_arg_validation,
+    _binary_confusion_matrix_tensor_validation,
+    _binary_confusion_matrix_format,
+    _binary_confusion_matrix_update,
+    _multiclass_confusion_matrix_arg_validation,
+    _multiclass_confusion_matrix_tensor_validation,
+    _multiclass_confusion_matrix_format,
+    _multiclass_confusion_matrix_update,
+    _multilabel_confusion_matrix_arg_validation,
+    _multilabel_confusion_matrix_tensor_validation,
+    _multilabel_confusion_matrix_format,
+    _multilabel_confusion_matrix_update,
+)
+
+
+def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor:
+    """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef score."""
+    # convert multilabel into binary
+    confmat = confmat.sum(0) if confmat.ndim == 3 else confmat
+
+    tk = confmat.sum(dim=-1).float()
+    pk = confmat.sum(dim=-2).float()
+    c = torch.trace(confmat).float()
+    s = confmat.sum().float()
+
+    cov_ytyp = c * s - sum(tk * pk)
+    cov_ypyp = s**2 - sum(pk * pk)
+    cov_ytyt = s**2 - sum(tk * tk)
+
+    if cov_ypyp * cov_ytyt == 0:
+        return torch.tensor(0, dtype=confmat.dtype, device=confmat.device)
+    else:
+        return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp)
+
+
+def binary_matthews_corrcoef(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    r"""Calculates `Matthews correlation coefficient`_ for binary tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
+      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
+      we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations. 
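The covariance form used in `_matthews_corrcoef_reduce` above reduces, for a 2x2 table, to the textbook (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)) expression. A worked sketch on the confusion matrix produced by the docstring example below:

import torch

confmat = torch.tensor([[2.0, 0.0], [1.0, 1.0]])  # rows: true class, cols: predicted class

tk = confmat.sum(dim=-1)   # per-true-class totals
pk = confmat.sum(dim=-2)   # per-predicted-class totals
c = torch.trace(confmat)   # correctly classified samples
s = confmat.sum()          # total samples

cov_ytyp = c * s - (tk * pk).sum()
cov_ypyp = s**2 - (pk * pk).sum()
cov_ytyt = s**2 - (tk * tk).sum()
print(cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp))  # tensor(0.5774)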
+
+    Example (preds is int tensor):
+        >>> from torchmetrics.functional import binary_matthews_corrcoef
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0, 1, 0, 0])
+        >>> binary_matthews_corrcoef(preds, target)
+        tensor(0.5774)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.functional import binary_matthews_corrcoef
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
+        >>> binary_matthews_corrcoef(preds, target)
+        tensor(0.5774)
+
+    """
+    if validate_args:
+        _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize=None)
+        _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
+    preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index)
+    confmat = _binary_confusion_matrix_update(preds, target)
+    return _matthews_corrcoef_reduce(confmat)
+
+
+def multiclass_matthews_corrcoef(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    r"""Calculates `Matthews correlation coefficient`_ for multiclass tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
+      we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into
+      an int tensor.
+    - ``target`` (int tensor): ``(N, ...)``

-def binary_matthews_corrcoef():
-    pass
+    Additional dimension ``...`` will be flattened into the batch dimension.

+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.

-def multiclass_matthews_corrcoef():
-    pass
+    Example (pred is integer tensor):
+        >>> from torchmetrics.functional import multiclass_matthews_corrcoef
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
+        tensor(0.7000)

+    Example (pred is float tensor):
+        >>> from torchmetrics.functional import multiclass_matthews_corrcoef
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([
+        ...   [0.16, 0.26, 0.58],
+        ...   [0.22, 0.61, 0.17],
+        ...   [0.71, 0.09, 0.20],
+        ...   [0.05, 0.82, 0.13],
+        ...   
])
+        >>> multiclass_matthews_corrcoef(preds, target, num_classes=3)
+        tensor(0.7000)
+
+    """
+    if validate_args:
+        _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize=None)
+        _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
+    preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index)
+    confmat = _multiclass_confusion_matrix_update(preds, target, num_classes)
+    return _matthews_corrcoef_reduce(confmat)

-def multilabel_matthews_corrcoef():
-    pass
+
+def multilabel_matthews_corrcoef(
+    preds: Tensor,
+    target: Tensor,
+    num_labels: int,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    r"""Calculates `Matthews correlation coefficient`_ for multilabel tasks. This metric measures
+    the general correlation or quality of a classification.
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
+      [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally,
+      we convert to int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, C, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        preds: Tensor with predictions
+        target: Tensor with true labels
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations. 
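The zero-denominator guard in the reduction is worth noting here: if either the predictions or the targets are constant, one of the covariance terms vanishes and the function returns a zero score instead of dividing by zero. A minimal sketch with illustrative inputs:

import torch

from torchmetrics.functional.classification.matthews_corrcoef import binary_matthews_corrcoef

target = torch.tensor([1, 1, 0, 0])
preds = torch.zeros(4, dtype=torch.long)  # degenerate model: always predicts class 0
print(binary_matthews_corrcoef(preds, target))  # zero score rather than NaN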
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_matthews_corrcoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_matthews_corrcoef + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_matthews_corrcoef(preds, target, num_labels=3) + tensor(0.3333) + + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize=None) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _matthews_corrcoef_reduce(confmat) # -------------------------- Old stuff -------------------------- diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 2502fb25b8a..bab2ac2e382 100644 --- a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -11,139 +11,425 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import matthews_corrcoef as sk_matthews_corrcoef -from torchmetrics.classification.matthews_corrcoef import MatthewsCorrCoef -from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.matthews_corrcoef import ( + BinaryMatthewsCorrCoef, + MulticlassMatthewsCorrCoef, + MultilabelMatthewsCorrCoef, +) +from torchmetrics.functional.classification.matthews_corrcoef import ( + binary_matthews_corrcoef, + multiclass_matthews_corrcoef, + multilabel_matthews_corrcoef, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) -def _sk_matthews_corrcoef_binary_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) - - -def _sk_matthews_corrcoef_binary(preds, target): - sk_preds = 
preds.view(-1).numpy()
- sk_target = target.view(-1).numpy()
-
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+def _sk_matthews_corrcoef_binary(preds, target, ignore_index=None):
+ preds = preds.view(-1).numpy()
+ target = target.view(-1).numpy()
+ if np.issubdtype(preds.dtype, np.floating):
+ if not ((0 < preds) & (preds < 1)).all():
+ preds = sigmoid(preds)
+ preds = (preds >= THRESHOLD).astype(np.uint8)
+ if ignore_index is not None:
+ idx = target == ignore_index
+ target = target[~idx]
+ preds = preds[~idx]
+ return sk_matthews_corrcoef(y_true=target, y_pred=preds)
-def _sk_matthews_corrcoef_multilabel_prob(preds, target):
- sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
- sk_target = target.view(-1).numpy()
-
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+@pytest.mark.parametrize("input", _binary_cases)
+class TestBinaryMatthewsCorrCoef(MetricTester):
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+ @pytest.mark.parametrize("ddp", [True, False])
+ def test_binary_matthews_corrcoef(self, input, ddp, ignore_index):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
+ self.run_class_metric_test(
+ ddp=ddp,
+ preds=preds,
+ target=target,
+ metric_class=BinaryMatthewsCorrCoef,
+ sk_metric=partial(_sk_matthews_corrcoef_binary, ignore_index=ignore_index),
+ metric_args={
+ "threshold": THRESHOLD,
+ "ignore_index": ignore_index,
+ },
+ )
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+ def test_binary_matthews_corrcoef_functional(self, input, ignore_index):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
+ self.run_functional_metric_test(
+ preds=preds,
+ target=target,
+ metric_functional=binary_matthews_corrcoef,
+ sk_metric=partial(_sk_matthews_corrcoef_binary, ignore_index=ignore_index),
+ metric_args={
+ "threshold": THRESHOLD,
+ "ignore_index": ignore_index,
+ },
+ )
-def _sk_matthews_corrcoef_multilabel(preds, target):
- sk_preds = preds.view(-1).numpy()
- sk_target = target.view(-1).numpy()
+ def test_binary_matthews_corrcoef_differentiability(self, input):
+ preds, target = input
+ self.run_differentiability_test(
+ preds=preds,
+ target=target,
+ metric_module=BinaryMatthewsCorrCoef,
+ metric_functional=binary_matthews_corrcoef,
+ metric_args={"threshold": THRESHOLD},
+ )
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_binary_matthews_corrcoef_dtype_cpu(self, input, dtype):
+ preds, target = input
+ if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+ pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+ if (preds < 0).any() and dtype == torch.half:
+ pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+ self.run_precision_test_cpu(
+ preds=preds,
+ target=target,
+ metric_module=BinaryMatthewsCorrCoef,
+ metric_functional=binary_matthews_corrcoef,
+ metric_args={"threshold": THRESHOLD},
+ dtype=dtype,
+ )
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_binary_matthews_corrcoef_dtype_gpu(self, input, dtype):
+ preds, target = input
+ self.run_precision_test_gpu(
+ preds=preds,
+ target=target,
+ metric_module=BinaryMatthewsCorrCoef,
+ metric_functional=binary_matthews_corrcoef,
+ metric_args={"threshold": THRESHOLD},
+ dtype=dtype,
+ )
-def _sk_matthews_corrcoef_multiclass_prob(preds, target):
- sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
- sk_target = target.view(-1).numpy()
-
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+def _sk_matthews_corrcoef_multiclass(preds, target, ignore_index=None):
+ preds = preds.numpy()
+ target = target.numpy()
+ if np.issubdtype(preds.dtype, np.floating):
+ preds = np.argmax(preds, axis=1)
+ preds = preds.flatten()
+ target = target.flatten()
+ if ignore_index is not None:
+ idx = target == ignore_index
+ target = target[~idx]
+ preds = preds[~idx]
+ return sk_matthews_corrcoef(y_true=target, y_pred=preds)
-def _sk_matthews_corrcoef_multiclass(preds, target):
- sk_preds = preds.view(-1).numpy()
- sk_target = target.view(-1).numpy()
-
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+@pytest.mark.parametrize("input", _multiclass_cases)
+class TestMulticlassMatthewsCorrCoef(MetricTester):
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+ @pytest.mark.parametrize("ddp", [True, False])
+ def test_multiclass_matthews_corrcoef(self, input, ddp, ignore_index):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
+ self.run_class_metric_test(
+ ddp=ddp,
+ preds=preds,
+ target=target,
+ metric_class=MulticlassMatthewsCorrCoef,
+ sk_metric=partial(_sk_matthews_corrcoef_multiclass, ignore_index=ignore_index),
+ metric_args={
+ "num_classes": NUM_CLASSES,
+ "ignore_index": ignore_index,
+ },
+ )
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+ def test_multiclass_matthews_corrcoef_functional(self, input, ignore_index):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
+ self.run_functional_metric_test(
+ preds=preds,
+ target=target,
+ metric_functional=multiclass_matthews_corrcoef,
+ sk_metric=partial(_sk_matthews_corrcoef_multiclass, ignore_index=ignore_index),
+ metric_args={
+ "num_classes": NUM_CLASSES,
+ "ignore_index": ignore_index,
+ },
+ )
-def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target):
- sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
- sk_target = target.view(-1).numpy()
+ def test_multiclass_matthews_corrcoef_differentiability(self, input):
+ preds, target = input
+ self.run_differentiability_test(
+ preds=preds,
+ target=target,
+ metric_module=MulticlassMatthewsCorrCoef,
+ metric_functional=multiclass_matthews_corrcoef,
+ metric_args={"num_classes": NUM_CLASSES},
+ )
- return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_multiclass_matthews_corrcoef_dtype_cpu(self, input, dtype):
+ preds, target = input
+ if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+ pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+ self.run_precision_test_cpu(
+ preds=preds,
+ target=target,
+ metric_module=MulticlassMatthewsCorrCoef,
+ metric_functional=multiclass_matthews_corrcoef,
+ metric_args={"num_classes": NUM_CLASSES},
+ dtype=dtype,
+ )
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_multiclass_matthews_corrcoef_dtype_gpu(self, input, dtype):
+ preds, target = input
+ self.run_precision_test_gpu(
+ preds=preds,
+ target=target,
+ metric_module=MulticlassMatthewsCorrCoef,
+ metric_functional=multiclass_matthews_corrcoef,
+ metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -def _sk_matthews_corrcoef_multidim_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) +def _sk_matthews_corrcoef_multilabel(preds, target, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_matthews_corrcoef(y_true=target, y_pred=preds) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES), - ], -) -class TestMatthewsCorrCoef(MetricTester): +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelMatthewsCorrCoef(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_multilabel_matthews_corrcoef(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=MatthewsCorrCoef, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, + metric_class=MultilabelMatthewsCorrCoef, + sk_metric=partial(_sk_matthews_corrcoef_multilabel, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, }, ) - def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_matthews_corrcoef_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( - preds, - target, - metric_functional=matthews_corrcoef, - sk_metric=sk_metric, + preds=preds, + target=target, + metric_functional=multilabel_matthews_corrcoef, + sk_metric=partial(_sk_matthews_corrcoef_multilabel, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, + "ignore_index": ignore_index, }, ) - def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes): + def test_multilabel_matthews_corrcoef_differentiability(self, input): + preds, target = input 
self.run_differentiability_test(
 preds=preds,
 target=target,
- metric_module=MatthewsCorrCoef,
- metric_functional=matthews_corrcoef,
- metric_args={
- "num_classes": num_classes,
- "threshold": THRESHOLD,
- },
+ metric_module=MultilabelMatthewsCorrCoef,
+ metric_functional=multilabel_matthews_corrcoef,
+ metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
 )
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_multilabel_matthews_corrcoef_dtype_cpu(self, input, dtype):
+ preds, target = input
+ if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+ pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+ if (preds < 0).any() and dtype == torch.half:
+ pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+ self.run_precision_test_cpu(
+ preds=preds,
+ target=target,
+ metric_module=MultilabelMatthewsCorrCoef,
+ metric_functional=multilabel_matthews_corrcoef,
+ metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
+ dtype=dtype,
+ )
-def test_zero_case():
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_multilabel_matthews_corrcoef_dtype_gpu(self, input, dtype):
+ preds, target = input
+ self.run_precision_test_gpu(
+ preds=preds,
+ target=target,
+ metric_module=MultilabelMatthewsCorrCoef,
+ metric_functional=multilabel_matthews_corrcoef,
+ metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
+ dtype=dtype,
+ )
+
+
+def test_zero_case_in_multiclass():
 """Cases where the denominator in the matthews corrcoef is 0; the score should return 0."""
 # Example where neither 1 nor 2 is present in the target tensor
- out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3)
+ out = multiclass_matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3)
 assert out == 0.0
+
+
+# -------------------------- Old stuff --------------------------
+
+# def _sk_matthews_corrcoef_binary_prob(preds, target):
+# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_binary(preds, target):
+# sk_preds = preds.view(-1).numpy()
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_multilabel_prob(preds, target):
+# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_multilabel(preds, target):
+# sk_preds = preds.view(-1).numpy()
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_multiclass_prob(preds, target):
+# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy()
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_multiclass(preds, target):
+# sk_preds = preds.view(-1).numpy()
+# sk_target = target.view(-1).numpy()

+# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds)


+# def _sk_matthews_corrcoef_multidim_multiclass_prob(preds, target):
+# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy()
+# sk_target = target.view(-1).numpy()

+# return
sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +# def _sk_matthews_corrcoef_multidim_multiclass(preds, target): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_matthews_corrcoef(y_true=sk_target, y_pred=sk_preds) + + +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_matthews_corrcoef_binary_prob, 2), +# (_input_binary.preds, _input_binary.target, _sk_matthews_corrcoef_binary, 2), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_matthews_corrcoef_multilabel_prob, 2), +# (_input_mlb.preds, _input_mlb.target, _sk_matthews_corrcoef_multilabel, 2), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_matthews_corrcoef_multiclass_prob, NUM_CLASSES), +# (_input_mcls.preds, _input_mcls.target, _sk_matthews_corrcoef_multiclass, NUM_CLASSES), +# ( +# _input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_matthews_corrcoef_multidim_multiclass_prob, NUM_CLASSES +# ), +# (_input_mdmc.preds, _input_mdmc.target, _sk_matthews_corrcoef_multidim_multiclass, NUM_CLASSES), +# ], +# ) +# class TestMatthewsCorrCoef(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_matthews_corrcoef(self, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=MatthewsCorrCoef, +# sk_metric=sk_metric, +# dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# }, +# ) + +# def test_matthews_corrcoef_functional(self, preds, target, sk_metric, num_classes): +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=matthews_corrcoef, +# sk_metric=sk_metric, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# }, +# ) + +# def test_matthews_corrcoef_differentiability(self, preds, target, sk_metric, num_classes): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=MatthewsCorrCoef, +# metric_functional=matthews_corrcoef, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# }, +# ) + + +# def test_zero_case(): +# """Cases where the denominator in the matthews corrcoef is 0, the score should return 0.""" +# # Example where neither 1 or 2 is present in the target tensor +# out = matthews_corrcoef(torch.tensor([0, 1, 2]), torch.tensor([0, 0, 0]), 3) +# assert out == 0.0 From 7a98b334f3db38ab9cbb4725c07066a85ac82655 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Jul 2022 13:57:41 +0000 Subject: [PATCH 66/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/matthews_corrcoef.py | 2 +- .../functional/classification/matthews_corrcoef.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index dbfb46a7c05..4b16ed17d02 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -19,8 +19,8 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import ( 
_matthews_corrcoef_compute, + _matthews_corrcoef_reduce, _matthews_corrcoef_update, - _matthews_corrcoef_reduce ) from torchmetrics.metric import Metric diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 4db8a28dcb2..de278cd7692 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -17,24 +17,25 @@ from torch import Tensor from torchmetrics.functional.classification.confusion_matrix import ( - _confusion_matrix_update, _binary_confusion_matrix_arg_validation, - _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, _binary_confusion_matrix_update, + _confusion_matrix_update, _multiclass_confusion_matrix_arg_validation, - _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, _multiclass_confusion_matrix_update, _multilabel_confusion_matrix_arg_validation, - _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, _multilabel_confusion_matrix_update, ) def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: - """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef score.""" + """Reduce an un-normalized confusion matrix of shape (n_classes, n_classes) into the matthews corrcoef + score.""" # convert multilabel into binary confmat = confmat.sum(0) if confmat.ndim == 3 else confmat From f8196013e2ae39fd45136fb4aefef5228af3a901 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 14 Jul 2022 11:59:39 +0200 Subject: [PATCH 67/74] working jaccard --- src/torchmetrics/classification/jaccard.py | 43 +- .../classification/matthews_corrcoef.py | 2 +- .../functional/classification/f_beta.py | 7 +- .../functional/classification/jaccard.py | 137 +++- .../classification/matthews_corrcoef.py | 8 +- src/torchmetrics/utilities/compute.py | 6 + .../unittests/classification/test_jaccard.py | 680 +++++++++++++----- 7 files changed, 658 insertions(+), 225 deletions(-) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 8c08582f8fd..d02fcd643f5 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional +from typing import Any, Literal, Optional import torch from torch import Tensor from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat +from torchmetrics.functional.classification.jaccard import ( + _jaccard_from_confmat, + _jaccard_index_reduce, + _multiclass_jaccard_index_arg_validation, + _multilabel_jaccard_index_arg_validation, +) class BinaryJaccardIndex(BinaryConfusionMatrix): - """""" + """ """ is_differentiable: bool = False higher_is_better: bool = True @@ -35,14 +40,12 @@ def __init__( validate_args: bool = True, **kwargs: Any, ) -> None: - if validate_args: - _binary_jaccard_index_validate_args(threshold, ignore_index) - super().__init__(threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=False, **kwargs) + super().__init__( + threshold=threshold, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs + ) def compute(self) -> Tensor: - return _binary_jaccard_index_compute( - self.confmat, - ) + return _jaccard_index_reduce(self.confmat, average="binary") class MulticlassJaccardIndex(MulticlassConfusionMatrix): @@ -56,17 +59,20 @@ def __init__( self, num_classes: int, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__( - num_classes=num_classes, ignore_index=ignore_index, normalize=None, validate_args=validate_args, **kwargs + num_classes=num_classes, ignore_index=ignore_index, normalize=None, validate_args=False, **kwargs ) + if validate_args: + _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average) + self.validate_args = validate_args + self.average = average def compute(self) -> Tensor: - return _multiclass_jaccard_index_compute( - self.confmat, - ) + return _jaccard_index_reduce(self.confmat, average=self.average) class MultilabelJaccardIndex(MultilabelConfusionMatrix): @@ -81,6 +87,7 @@ def __init__( num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", validate_args: bool = True, **kwargs: Any, ) -> None: @@ -89,14 +96,16 @@ def __init__( threshold=threshold, ignore_index=ignore_index, normalize=None, - validate_args=validate_args, + validate_args=False, **kwargs, ) + if validate_args: + _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average) + self.validate_args = validate_args + self.average = average def compute(self) -> Tensor: - return _multilabel_jaccard_index_compute( - self.confmat, - ) + return _jaccard_index_reduce(self.confmat, average=self.average) # -------------------------- Old stuff -------------------------- diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index dbfb46a7c05..4b16ed17d02 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -19,8 +19,8 @@ from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix from torchmetrics.functional.classification.matthews_corrcoef import ( _matthews_corrcoef_compute, + _matthews_corrcoef_reduce, _matthews_corrcoef_update, - 
_matthews_corrcoef_reduce ) from torchmetrics.metric import Metric diff --git a/src/torchmetrics/functional/classification/f_beta.py b/src/torchmetrics/functional/classification/f_beta.py index e523b4b533d..575872b33be 100644 --- a/src/torchmetrics/functional/classification/f_beta.py +++ b/src/torchmetrics/functional/classification/f_beta.py @@ -17,16 +17,11 @@ from torch import Tensor from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod as AvgMethod from torchmetrics.utilities.enums import MDMCAverageMethod -def _safe_divide(num: Tensor, denom: Tensor) -> Tensor: - """prevent zero division.""" - denom[denom == 0.0] = 1 - return num / denom - - def _fbeta_compute( tp: Tensor, fp: Tensor, diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 2d89088522b..eeff298aa62 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -15,20 +15,143 @@ import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_format, + _binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) +from torchmetrics.utilities.compute import _safe_divide + + +def _jaccard_index_reduce( + confmat: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], +) -> Tensor: + """Perform reduction of an un-normalized confusion matrix into jaccard score + + Args: + confmat: tensor with un-normalized confusionmatrix + average: reduction method + + - ``'binary'``: binary reduction, expects a 2x2 matrix + - ``'macro'``: Calculate the metric for each class separately, and average the + metrics across classes (with equal weights for each class). + - ``'micro'``: Calculate the metric globally, across all samples and classes. + - ``'weighted'``: Calculate the metric for each class separately, and average the + metrics across classes, weighting each class by its support (``tp + fn``). + - ``'none'`` or ``None``: Calculate the metric for each class separately, and return + the metric for every class. 
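+
+ A minimal doctest-style sketch of the ``"binary"`` case (the 2x2 confusion
+ matrix below is hypothetical, chosen only to illustrate the reduction):
+
+ >>> confmat = torch.tensor([[2, 1], [1, 2]])
+ >>> _jaccard_index_reduce(confmat, average="binary") # tp / (tp + fp + fn)
+ tensor(0.5000)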
+ + """ + allowed_average = ["binary", "micro", "macro", "weighted", "none", None] + if average not in allowed_average: + raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + if average == "binary": + return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1]) + else: + if confmat.ndim == 3: # multilabel + num = confmat[:, 1, 1] + denom = confmat[:, 1, 1] + confmat[:, 0, 1] + confmat[:, 1, 0] + else: # multiclass + num = torch.diag(confmat) + denom = confmat.sum(0) + confmat.sum(1) - num + + if average == "micro": + num = num.sum() + denom = denom.sum() + + jaccard = _safe_divide(num, denom) + + if average is None or average == "none": + return jaccard + if average == "weighted": + weights = confmat[:, 1, 1] + confmat[:, 1, 0] if confmat.ndim == 3 else confmat.sum(1) + else: + weights = torch.ones_like(jaccard) + return ((weights * jaccard) / weights.sum()).sum() + + +def binary_jaccard_index( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _jaccard_index_reduce(confmat, average="binary") + -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update +def _multiclass_jaccard_index_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = None, +) -> None: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}.") -def binary_jaccard_index(): - pass +def multiclass_jaccard_index( + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + validate_args: bool = True, +) -> Tensor: + if validate_args: + _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _jaccard_index_reduce(confmat, average=average) -def multiclass_jaccard_index(): - pass +def _multilabel_jaccard_index_arg_validation( + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", +) -> None: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index) + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}.") -def multilabel_jaccard_index(): - pass +def multilabel_jaccard_index( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro", + validate_args: bool = 
True,
+) -> Tensor:
+ if validate_args:
+ _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index, average)
+ _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index)
+ preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index)
+ confmat = _multilabel_confusion_matrix_update(preds, target, num_labels)
+ return _jaccard_index_reduce(confmat, average=average)
# -------------------------- Old stuff --------------------------
diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py
index 4db8a28dcb2..f7195ed325b 100644
--- a/src/torchmetrics/functional/classification/matthews_corrcoef.py
+++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py
@@ -17,18 +17,18 @@
 from torch import Tensor
 from torchmetrics.functional.classification.confusion_matrix import (
- _confusion_matrix_update,
 _binary_confusion_matrix_arg_validation,
- _binary_confusion_matrix_tensor_validation,
 _binary_confusion_matrix_format,
+ _binary_confusion_matrix_tensor_validation,
 _binary_confusion_matrix_update,
+ _confusion_matrix_update,
 _multiclass_confusion_matrix_arg_validation,
- _multiclass_confusion_matrix_tensor_validation,
 _multiclass_confusion_matrix_format,
+ _multiclass_confusion_matrix_tensor_validation,
 _multiclass_confusion_matrix_update,
 _multilabel_confusion_matrix_arg_validation,
- _multilabel_confusion_matrix_tensor_validation,
 _multilabel_confusion_matrix_format,
+ _multilabel_confusion_matrix_tensor_validation,
 _multilabel_confusion_matrix_update,
)
diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index f496baff818..a57f47d1e4e 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -38,3 +38,9 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor:
 res = x * torch.log(y)
 res[x == 0] = 0.0
 return res
+
+
+def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:
+ """Prevent zero division by replacing zero entries of ``denom`` with one."""
+ denom[denom == 0.0] = 1
+ return num / denom
diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py
index 945001a9220..2b5e7ff9741 100644
--- a/tests/unittests/classification/test_jaccard.py
+++ b/tests/unittests/classification/test_jaccard.py
@@ -16,224 +16,524 @@
 import numpy as np
 import pytest
 import torch
-from sklearn.metrics import jaccard_score as sk_jaccard_score
-from torch import Tensor, tensor
-
-from torchmetrics.classification.jaccard import JaccardIndex
-from torchmetrics.functional import jaccard_index
-from unittests.classification.inputs import _input_binary, _input_binary_prob
-from unittests.classification.inputs import _input_multiclass as _input_mcls
-from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob
-from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc
-from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob
-from unittests.classification.inputs import _input_multilabel as _input_mlb
-from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
-from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
-
-
-def _sk_jaccard_binary_prob(preds, target, average=None):
- sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8)
- sk_target = target.view(-1).numpy()
-
- return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds,
average=average) - - -def _sk_jaccard_binary(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_jaccard_multilabel_prob(preds, target, average=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) - - -def _sk_jaccard_multilabel(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +from sklearn.metrics import jaccard_score as sk_jaccard_index, confusion_matrix as sk_confusion_matrix +from scipy.special import expit as sigmoid +from torchmetrics.classification.jaccard import ( + BinaryJaccardIndex, + MulticlassJaccardIndex, + MultilabelJaccardIndex, +) +from torchmetrics.functional.classification.jaccard import ( + binary_jaccard_index, + multiclass_jaccard_index, + multilabel_jaccard_index, +) +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 + + +def _sk_jaccard_index_binary(preds, target, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_jaccard_index(y_true=target, y_pred=preds) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryJaccardIndex(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_jaccard_index(self, input, ddp, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryJaccardIndex, + sk_metric=partial(_sk_jaccard_index_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_jaccard_index_functional(self, input, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_jaccard_index, + sk_metric=partial(_sk_jaccard_index_binary, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + }, + ) + def test_binary_jaccard_index_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + ) -def _sk_jaccard_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def 
test_binary_jaccard_index_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryJaccardIndex, + metric_functional=binary_jaccard_index, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_jaccard_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +def _sk_jaccard_index_multiclass(preds, target, ignore_index=None, average="macro"): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_jaccard_index(y_true=target, y_pred=preds, average=average) -def _sk_jaccard_multidim_multiclass_prob(preds, target, average=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassJaccardIndex(MetricTester): + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_jaccard_index(self, input, ddp, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassJaccardIndex, + sk_metric=partial(_sk_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_jaccard_index_functional(self, input, ignore_index, average): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_jaccard_index, + sk_metric=partial(_sk_jaccard_index_multiclass, ignore_index=ignore_index, average=average), + metric_args={ + "num_classes": NUM_CLASSES, + "ignore_index": ignore_index, + "average": average, + }, + ) + def test_multiclass_jaccard_index_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + 
preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + ) -def _sk_jaccard_multidim_multiclass(preds, target, average=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_jaccard_index_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassJaccardIndex, + metric_functional=multiclass_jaccard_index, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_jaccard_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_jaccard_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_jaccard_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_jaccard_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_jaccard_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_jaccard_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_jaccard_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_jaccard_multidim_multiclass, NUM_CLASSES), - ], -) -class TestJaccardIndex(MetricTester): +def _sk_jaccard_index_multilabel(preds, target, ignore_index=None, average="macro"): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + if ignore_index is not None: + if average == "micro": + return _sk_jaccard_index_binary(torch.tensor(preds), torch.tensor(target), ignore_index) + scores, weights = [], [] + for i in range(preds.shape[1]): + p, t = preds[:, i], target[:, i] + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + confmat = sk_confusion_matrix(t, p, labels=[0, 1]) + scores.append(sk_jaccard_index(t, p)) + weights.append(confmat[1, 0] + confmat[1, 1]) + scores = np.stack(scores, axis=0) + weights = np.stack(weights, axis=0) + if average is None or average == "none": + return scores + elif average == "macro": + return scores.mean() + return ((scores * weights) / weights.sum()).sum() + else: + return sk_jaccard_index(y_true=target, y_pred=preds, average=average) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelJaccardIndex(MetricTester): + 
@pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
+ @pytest.mark.parametrize("ignore_index", [None]) # , -1, 0])
 @pytest.mark.parametrize("ddp", [True, False])
- @pytest.mark.parametrize("dist_sync_on_step", [True, False])
- def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
- # average = "macro" if reduction == "elementwise_mean" else None # convert tags
+ def test_multilabel_jaccard_index(self, input, ddp, ignore_index, average):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
 self.run_class_metric_test(
 ddp=ddp,
 preds=preds,
 target=target,
- metric_class=JaccardIndex,
- sk_metric=partial(sk_metric, average=average),
- dist_sync_on_step=dist_sync_on_step,
- metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
+ metric_class=MultilabelJaccardIndex,
+ sk_metric=partial(_sk_jaccard_index_multilabel, ignore_index=ignore_index, average=average),
+ metric_args={
+ "num_labels": NUM_CLASSES,
+ "ignore_index": ignore_index,
+ "average": average,
+ },
 )

- def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes):
- # average = "macro" if reduction == "elementwise_mean" else None # convert tags
+ @pytest.mark.parametrize("average", ["macro", "micro", "weighted", None])
+ @pytest.mark.parametrize("ignore_index", [None, -1, 0])
+ def test_multilabel_jaccard_index_functional(self, input, ignore_index, average):
+ preds, target = input
+ if ignore_index is not None:
+ target = inject_ignore_index(target, ignore_index)
 self.run_functional_metric_test(
- preds,
- target,
- metric_functional=jaccard_index,
- sk_metric=partial(sk_metric, average=average),
- metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
+ preds=preds,
+ target=target,
+ metric_functional=multilabel_jaccard_index,
+ sk_metric=partial(_sk_jaccard_index_multilabel, ignore_index=ignore_index, average=average),
+ metric_args={
+ "num_labels": NUM_CLASSES,
+ "ignore_index": ignore_index,
+ "average": average,
+ },
 )

- def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes):
+ def test_multilabel_jaccard_index_differentiability(self, input):
+ preds, target = input
 self.run_differentiability_test(
 preds=preds,
 target=target,
- metric_module=JaccardIndex,
- metric_functional=jaccard_index,
- metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
+ metric_module=MultilabelJaccardIndex,
+ metric_functional=multilabel_jaccard_index,
+ metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
 )

+ @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+ def test_multilabel_jaccard_index_dtype_cpu(self, input, dtype):
+ preds, target = input
+ if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+ pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+ if (preds < 0).any() and dtype == torch.half:
+ pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+ self.run_precision_test_cpu(
+ preds=preds,
+ target=target,
+ metric_module=MultilabelJaccardIndex,
+ metric_functional=multilabel_jaccard_index,
+ metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
+ dtype=dtype,
+ )

-@pytest.mark.parametrize(
- ["half_ones", "average", "ignore_index", "expected"],
- [
- (False, "none", None, Tensor([1, 1, 1])),
- (False, "macro", None, Tensor([1])),
- (False, "none", 0, Tensor([1, 1])),
- (True,
"none", None, Tensor([0.5, 0.5, 0.5])), - (True, "macro", None, Tensor([0.5])), - (True, "none", 0, Tensor([2 / 3, 1 / 2])), - ], -) -def test_jaccard(half_ones, average, ignore_index, expected): - preds = (torch.arange(120) % 3).view(-1, 1) - target = (torch.arange(120) % 3).view(-1, 1) - if half_ones: - preds[:60] = 1 - jaccard_val = jaccard_index( - preds=preds, - target=target, - average=average, - num_classes=3, - ignore_index=ignore_index, - # reduction=reduction, - ) - assert torch.allclose(jaccard_val, expected, atol=1e-9) - - -# test `absent_score` -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], - [ - # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid - # scores the function can return ([0., 1.] range, inclusive). - # 2 classes, class 0 is correct everywhere, class 1 is absent. - ([0], [0], None, -1.0, 2, [1.0, -1.0]), - ([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), - # absent_score not applied if only class 0 is present and it's the only class. - ([0], [0], None, -1.0, 1, [1.0]), - # 2 classes, class 1 is correct everywhere, class 0 is absent. - ([1], [1], None, -1.0, 2, [-1.0, 1.0]), - ([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), - # When 0 index ignored, class 0 does not get a score (not even the absent_score). - ([1], [1], 0, -1.0, 2, [1.0]), - # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. - ([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), - ([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), - # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. - ([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), - ([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class - # 2 is absent. - ([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), - # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class - # 2 is absent. - ([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), - # Sanity checks with absent_score of 1.0. - ([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), - ([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), - ], -) -def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): - jaccard_val = jaccard_index( - preds=tensor(pred), - target=tensor(target), - average=None, - ignore_index=ignore_index, - absent_score=absent_score, - num_classes=num_classes, - # reduction="none", - ) - assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -# example data taken from -# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py -@pytest.mark.parametrize( - ["pred", "target", "ignore_index", "num_classes", "average", "expected"], - [ - # Ignoring an index outside of [0, num_classes-1] should have no effect. - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), - # Ignoring a valid index drops only that index from the result. 
- ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]), - # When reducing to mean or sum, the ignored index does not contribute to the output. - ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]), - # ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), - ], -) -def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected): - jaccard_val = jaccard_index( - preds=tensor(pred), - target=tensor(target), - average=average, - ignore_index=ignore_index, - num_classes=num_classes, - # reduction=reduction, - ) - assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_jaccard_index_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelJaccardIndex, + metric_functional=multilabel_jaccard_index, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + +# -------------------------- Old stuff -------------------------- + +# def _sk_jaccard_binary_prob(preds, target, average=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_binary(preds, target, average=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multilabel_prob(preds, target, average=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multilabel(preds, target, average=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multiclass_prob(preds, target, average=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multiclass(preds, target, average=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multidim_multiclass_prob(preds, target, average=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# def _sk_jaccard_multidim_multiclass(preds, target, average=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) + + +# @pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"]) +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_jaccard_binary_prob, 2), +# (_input_binary.preds, _input_binary.target, _sk_jaccard_binary, 
2), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_jaccard_multilabel_prob, 2), +# (_input_mlb.preds, _input_mlb.target, _sk_jaccard_multilabel, 2), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_jaccard_multiclass_prob, NUM_CLASSES), +# (_input_mcls.preds, _input_mcls.target, _sk_jaccard_multiclass, NUM_CLASSES), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_jaccard_multidim_multiclass_prob, NUM_CLASSES), +# (_input_mdmc.preds, _input_mdmc.target, _sk_jaccard_multidim_multiclass, NUM_CLASSES), +# ], +# ) +# class TestJaccardIndex(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): +# # average = "macro" if reduction == "elementwise_mean" else None # convert tags +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=JaccardIndex, +# sk_metric=partial(sk_metric, average=average), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, +# ) + +# def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes): +# # average = "macro" if reduction == "elementwise_mean" else None # convert tags +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=jaccard_index, +# sk_metric=partial(sk_metric, average=average), +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, +# ) + +# def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=JaccardIndex, +# metric_functional=jaccard_index, +# metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average}, +# ) + + +# @pytest.mark.parametrize( +# ["half_ones", "average", "ignore_index", "expected"], +# [ +# (False, "none", None, Tensor([1, 1, 1])), +# (False, "macro", None, Tensor([1])), +# (False, "none", 0, Tensor([1, 1])), +# (True, "none", None, Tensor([0.5, 0.5, 0.5])), +# (True, "macro", None, Tensor([0.5])), +# (True, "none", 0, Tensor([2 / 3, 1 / 2])), +# ], +# ) +# def test_jaccard(half_ones, average, ignore_index, expected): +# preds = (torch.arange(120) % 3).view(-1, 1) +# target = (torch.arange(120) % 3).view(-1, 1) +# if half_ones: +# preds[:60] = 1 +# jaccard_val = jaccard_index( +# preds=preds, +# target=target, +# average=average, +# num_classes=3, +# ignore_index=ignore_index, +# # reduction=reduction, +# ) +# assert torch.allclose(jaccard_val, expected, atol=1e-9) + + +# # test `absent_score` +# @pytest.mark.parametrize( +# ["pred", "target", "ignore_index", "absent_score", "num_classes", "expected"], +# [ +# # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid +# # scores the function can return ([0., 1.] range, inclusive). +# # 2 classes, class 0 is correct everywhere, class 1 is absent. +# ([0], [0], None, -1.0, 2, [1.0, -1.0]), +# ([0, 0], [0, 0], None, -1.0, 2, [1.0, -1.0]), +# # absent_score not applied if only class 0 is present and it's the only class. +# ([0], [0], None, -1.0, 1, [1.0]), +# # 2 classes, class 1 is correct everywhere, class 0 is absent. +# ([1], [1], None, -1.0, 2, [-1.0, 1.0]), +# ([1, 1], [1, 1], None, -1.0, 2, [-1.0, 1.0]), +# # When 0 index ignored, class 0 does not get a score (not even the absent_score). 
+# ([1], [1], 0, -1.0, 2, [1.0]), +# # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. +# ([0, 2], [0, 2], None, -1.0, 3, [1.0, -1.0, 1.0]), +# ([2, 0], [2, 0], None, -1.0, 3, [1.0, -1.0, 1.0]), +# # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. +# ([0, 1], [0, 1], None, -1.0, 3, [1.0, 1.0, -1.0]), +# ([1, 0], [1, 0], None, -1.0, 3, [1.0, 1.0, -1.0]), +# # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class +# # 2 is absent. +# ([0, 1], [0, 0], None, -1.0, 3, [0.5, 0.0, -1.0]), +# # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class +# # 2 is absent. +# ([0, 0], [0, 1], None, -1.0, 3, [0.5, 0.0, -1.0]), +# # Sanity checks with absent_score of 1.0. +# ([0, 2], [0, 2], None, 1.0, 3, [1.0, 1.0, 1.0]), +# ([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), +# ], +# ) +# def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): +# jaccard_val = jaccard_index( +# preds=tensor(pred), +# target=tensor(target), +# average=None, +# ignore_index=ignore_index, +# absent_score=absent_score, +# num_classes=num_classes, +# # reduction="none", +# ) +# assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) + + +# # example data taken from +# # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py +# @pytest.mark.parametrize( +# ["pred", "target", "ignore_index", "num_classes", "average", "expected"], +# [ +# # Ignoring an index outside of [0, num_classes-1] should have no effect. +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]), +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, "none", [1, 1 / 2, 2 / 3]), +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, "none", [1, 1 / 2, 2 / 3]), +# # Ignoring a valid index drops only that index from the result. +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "none", [1 / 2, 2 / 3]), +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]), +# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]), +# # When reducing to mean or sum, the ignored index does not contribute to the output. 
+# ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]), +# # ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), +# ], +# ) +# def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected): +# jaccard_val = jaccard_index( +# preds=tensor(pred), +# target=tensor(target), +# average=average, +# ignore_index=ignore_index, +# num_classes=num_classes, +# # reduction=reduction, +# ) +# assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) From 558cb682badd0d63ede56acf52db127fd1d2971c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Jul 2022 10:01:13 +0000 Subject: [PATCH 68/74] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/classification/jaccard.py | 2 +- .../functional/classification/jaccard.py | 3 +-- tests/unittests/classification/test_jaccard.py | 12 +++++------- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index d02fcd643f5..2f962082f22 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -27,7 +27,7 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): - """ """ + """""" is_differentiable: bool = False higher_is_better: bool = True diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index eeff298aa62..0461c966517 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -39,7 +39,7 @@ def _jaccard_index_reduce( confmat: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none", "binary"]], ) -> Tensor: - """Perform reduction of an un-normalized confusion matrix into jaccard score + """Perform reduction of an un-normalized confusion matrix into jaccard score. Args: confmat: tensor with un-normalized confusionmatrix @@ -53,7 +53,6 @@ def _jaccard_index_reduce( metrics across classes, weighting each class by its support (``tp + fn``). - ``'none'`` or ``None``: Calculate the metric for each class separately, and return the metric for every class. 
- """ allowed_average = ["binary", "micro", "macro", "weighted", "none", None] if average not in allowed_average: diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index 2b5e7ff9741..8559b0d96d0 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -16,21 +16,19 @@ import numpy as np import pytest import torch -from sklearn.metrics import jaccard_score as sk_jaccard_index, confusion_matrix as sk_confusion_matrix from scipy.special import expit as sigmoid -from torchmetrics.classification.jaccard import ( - BinaryJaccardIndex, - MulticlassJaccardIndex, - MultilabelJaccardIndex, -) +from sklearn.metrics import confusion_matrix as sk_confusion_matrix +from sklearn.metrics import jaccard_score as sk_jaccard_index + +from torchmetrics.classification.jaccard import BinaryJaccardIndex, MulticlassJaccardIndex, MultilabelJaccardIndex from torchmetrics.functional.classification.jaccard import ( binary_jaccard_index, multiclass_jaccard_index, multilabel_jaccard_index, ) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index -from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 def _sk_jaccard_index_binary(preds, target, ignore_index=None): From d0abe80fdbba07cbc368b3f624ba74c44bb16fcf Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 14 Jul 2022 12:22:28 +0200 Subject: [PATCH 69/74] docstrings for jaccard --- src/torchmetrics/classification/__init__.py | 25 +-- src/torchmetrics/classification/jaccard.py | 151 +++++++++++++++++- .../functional/classification/jaccard.py | 143 +++++++++++++++++ 3 files changed, 304 insertions(+), 15 deletions(-) diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 9ec9648a58b..39be39d8f87 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -11,6 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from torchmetrics.classification.confusion_matrix import (  # noqa: F401 isort:skip
+    BinaryConfusionMatrix,
+    ConfusionMatrix,
+    MulticlassConfusionMatrix,
+    MultilabelConfusionMatrix,
+)
+from torchmetrics.classification.stat_scores import (  # noqa: F401 isort:skip
+    BinaryStatScores,
+    MulticlassStatScores,
+    MultilabelStatScores,
+    StatScores,
+)
+
 from torchmetrics.classification.accuracy import Accuracy  # noqa: F401
 from torchmetrics.classification.auc import AUC  # noqa: F401
 from torchmetrics.classification.auroc import AUROC  # noqa: F401
@@ -20,12 +33,6 @@
 from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision  # noqa: F401
 from torchmetrics.classification.calibration_error import CalibrationError  # noqa: F401
 from torchmetrics.classification.cohen_kappa import BinaryCohenKappa, CohenKappa, MulticlassCohenKappa  # noqa: F401
-from torchmetrics.classification.confusion_matrix import (  # noqa: F401
-    BinaryConfusionMatrix,
-    ConfusionMatrix,
-    MulticlassConfusionMatrix,
-    MultilabelConfusionMatrix,
-)
 from torchmetrics.classification.dice import Dice  # noqa: F401
 from torchmetrics.classification.f_beta import F1Score, FBetaScore  # noqa: F401
 from torchmetrics.classification.hamming import HammingDistance  # noqa: F401
@@ -52,9 +59,3 @@
 from torchmetrics.classification.roc import ROC  # noqa: F401
 from torchmetrics.classification.specificity import Specificity  # noqa: F401
-from torchmetrics.classification.stat_scores import (  # noqa: F401
-    BinaryStatScores,
-    MulticlassStatScores,
-    MultilabelStatScores,
-    StatScores,
-)
diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py
index 2f962082f22..4879050049d 100644
--- a/src/torchmetrics/classification/jaccard.py
+++ b/src/torchmetrics/classification/jaccard.py
@@ -27,8 +27,52 @@
 
 
 class BinaryJaccardIndex(BinaryConfusionMatrix):
-    """"""
+    r"""Calculates the Jaccard index for binary tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
+      the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
+      Additionally, we convert to an int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        normalize: Normalization mode for confusion matrix. Choose from:
+
+            - ``None`` or ``'none'``: no normalization (default)
+            - ``'true'``: normalization over the targets (most commonly used)
+            - ``'pred'``: normalization over the predictions
+            - ``'all'``: normalization over the whole matrix
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics import BinaryJaccardIndex
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0, 1, 0, 0])
+        >>> metric = BinaryJaccardIndex()
+        >>> metric(preds, target)
+        tensor(0.5000)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics import BinaryJaccardIndex
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
+        >>> metric = BinaryJaccardIndex()
+        >>> metric(preds, target)
+        tensor(0.5000)
+    """
 
     is_differentiable: bool = False
     higher_is_better: bool = True
     full_state_update: bool = False
@@ -49,7 +93,59 @@ def compute(self) -> Tensor:
 
 
 class MulticlassJaccardIndex(MulticlassConfusionMatrix):
-    """"""
+    r"""Calculates the Jaccard index for multiclass tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
+      tensor we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits
+      into an int tensor.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: Calculates statistics for each label and computes a weighted average using their support
+            - ``"none"`` or ``None``: Calculates the statistic for each label and applies no reduction
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics import MulticlassJaccardIndex
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> metric = MulticlassJaccardIndex(num_classes=3)
+        >>> metric(preds, target)
+        tensor(0.6667)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics import MulticlassJaccardIndex
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([
+        ...     [0.16, 0.26, 0.58],
+        ...     [0.22, 0.61, 0.17],
+        ...     [0.71, 0.09, 0.20],
+        ...     [0.05, 0.82, 0.13],
+        ... ])
+        >>> metric = MulticlassJaccardIndex(num_classes=3)
+        >>> metric(preds, target)
+        tensor(0.6667)
+    """
 
     is_differentiable: bool = False
     higher_is_better: bool = True
@@ -76,7 +172,56 @@ def compute(self) -> Tensor:
 
 
 class MultilabelJaccardIndex(MultilabelConfusionMatrix):
-    """"""
+    r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
+      the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
+      Additionally, we convert to an int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, C, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: Calculates statistics for each label and computes a weighted average using their support
+            - ``"none"`` or ``None``: Calculates the statistic for each label and applies no reduction
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics import MultilabelJaccardIndex
+        >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
+        >>> metric = MultilabelJaccardIndex(num_labels=3)
+        >>> metric(preds, target)
+        tensor(0.5000)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics import MultilabelJaccardIndex
+        >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
+        >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]])
+        >>> metric = MultilabelJaccardIndex(num_labels=3)
+        >>> metric(preds, target)
+        tensor(0.5000)
+
+    """
 
     is_differentiable: bool = False
     higher_is_better: bool = True
diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py
index 0461c966517..cbac8e3c328 100644
--- a/src/torchmetrics/functional/classification/jaccard.py
+++ b/src/torchmetrics/functional/classification/jaccard.py
@@ -89,6 +89,50 @@ def binary_jaccard_index(
     ignore_index: Optional[int] = None,
     validate_args: bool = True,
 ) -> Tensor:
+    r"""Calculates the Jaccard index for binary tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside
+      the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
+      Additionally, we convert to an int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        normalize: Normalization mode for confusion matrix. Choose from:
+
+            - ``None`` or ``'none'``: no normalization (default)
+            - ``'true'``: normalization over the targets (most commonly used)
+            - ``'pred'``: normalization over the predictions
+            - ``'all'``: normalization over the whole matrix
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics.functional import binary_jaccard_index
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0, 1, 0, 0])
+        >>> binary_jaccard_index(preds, target)
+        tensor(0.5000)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.functional import binary_jaccard_index
+        >>> target = torch.tensor([1, 1, 0, 0])
+        >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01])
+        >>> binary_jaccard_index(preds, target)
+        tensor(0.5000)
+    """
     if validate_args:
         _binary_confusion_matrix_arg_validation(threshold, ignore_index)
         _binary_confusion_matrix_tensor_validation(preds, target, ignore_index)
@@ -116,6 +160,57 @@ def multiclass_jaccard_index(
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
     validate_args: bool = True,
 ) -> Tensor:
+    r"""Calculates the Jaccard index for multiclass tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point
+      tensor we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits
+      into an int tensor.
+    - ``target`` (int tensor): ``(N, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_classes: Integer specifying the number of classes
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: Calculates statistics for each label and computes a weighted average using their support
+            - ``"none"`` or ``None``: Calculates the statistic for each label and applies no reduction
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Example (preds is int tensor):
+        >>> from torchmetrics.functional import multiclass_jaccard_index
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([2, 1, 0, 1])
+        >>> multiclass_jaccard_index(preds, target, num_classes=3)
+        tensor(0.6667)
+
+    Example (preds is float tensor):
+        >>> from torchmetrics.functional import multiclass_jaccard_index
+        >>> target = torch.tensor([2, 1, 0, 0])
+        >>> preds = torch.tensor([
+        ...     [0.16, 0.26, 0.58],
+        ...     [0.22, 0.61, 0.17],
+        ...     [0.71, 0.09, 0.20],
+        ...     [0.05, 0.82, 0.13],
+        ... ])
+        >>> multiclass_jaccard_index(preds, target, num_classes=3)
+        tensor(0.6667)
+    """
     if validate_args:
         _multiclass_jaccard_index_arg_validation(num_classes, ignore_index, average)
         _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index)
@@ -145,6 +240,54 @@ def multilabel_jaccard_index(
     average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
     validate_args: bool = True,
 ) -> Tensor:
+    r"""Calculates the Jaccard index for multilabel tasks. The `Jaccard index`_ (also known as
+    the intersection over union or Jaccard similarity coefficient) is a statistic that can be
+    used to determine the similarity and diversity of a sample set. It is defined as the size
+    of the intersection divided by the size of the union of the sample sets:
+
+    .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|}
+
+    Accepts the following input tensors:
+
+    - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside
+      the [0,1] range we consider the input to be logits and will automatically apply sigmoid per element.
+      Additionally, we convert to an int tensor with thresholding using the value in ``threshold``.
+    - ``target`` (int tensor): ``(N, C, ...)``
+
+    Additional dimension ``...`` will be flattened into the batch dimension.
+
+    Args:
+        num_labels: Integer specifying the number of labels
+        threshold: Threshold for transforming probability to binary (0,1) predictions
+        ignore_index:
+            Specifies a target value that is ignored and does not contribute to the metric calculation
+        average:
+            Defines the reduction that is applied over labels. Should be one of the following:
+
+            - ``micro``: Sum statistics over all labels
+            - ``macro``: Calculate statistics for each label and average them
+            - ``weighted``: Calculates statistics for each label and computes a weighted average using their support
+            - ``"none"`` or ``None``: Calculates the statistic for each label and applies no reduction
+
+        validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+            Set to ``False`` for faster computations.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+ + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_jaccard_index + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_jaccard_index(preds, target, num_labels=3) + tensor(0.5000) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_jaccard_index + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_jaccard_index(preds, target, num_labels=3) + tensor(0.5000) + + """ if validate_args: _multilabel_jaccard_index_arg_validation(num_labels, threshold, ignore_index) _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) From 43a561663e89974fb22378a0eb007283f1c04080 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 14 Jul 2022 12:34:02 +0200 Subject: [PATCH 70/74] small improve --- .../functional/classification/matthews_corrcoef.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index de278cd7692..130e0e91a5e 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -48,10 +48,11 @@ def _matthews_corrcoef_reduce(confmat: Tensor) -> Tensor: cov_ypyp = s**2 - sum(pk * pk) cov_ytyt = s**2 - sum(tk * tk) - if cov_ypyp * cov_ytyt == 0: + denom = cov_ypyp * cov_ytyt + if denom == 0: return torch.tensor(0, dtype=confmat.dtype, device=confmat.device) else: - return cov_ytyp / torch.sqrt(cov_ytyt * cov_ypyp) + return cov_ytyp / torch.sqrt(denom) def binary_matthews_corrcoef( From 3e814a06236f59f9b958dfa47455d4d4146a8f85 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 14 Jul 2022 12:35:00 +0200 Subject: [PATCH 71/74] typing --- src/torchmetrics/classification/jaccard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 4879050049d..0c58d961de3 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Literal, Optional
+from typing import Any, Optional
 
 import torch
 from torch import Tensor
+from typing_extensions import Literal
 
 from torchmetrics.classification import BinaryConfusionMatrix, MulticlassConfusionMatrix, MultilabelConfusionMatrix
 from torchmetrics.classification.confusion_matrix import ConfusionMatrix

From 59162f7be72f0d0c0e24229520a0aeaa369f15c2 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Thu, 14 Jul 2022 14:35:26 +0200
Subject: [PATCH 72/74] fix doctest

---
 src/torchmetrics/functional/classification/jaccard.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py
index cbac8e3c328..481b1441353 100644
--- a/src/torchmetrics/functional/classification/jaccard.py
+++ b/src/torchmetrics/functional/classification/jaccard.py
@@ -57,6 +57,7 @@ def _jaccard_index_reduce(
     allowed_average = ["binary", "micro", "macro", "weighted", "none", None]
     if average not in allowed_average:
         raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
+    confmat = confmat.float()
     if average == "binary":
         return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])
     else:

From 4c7de2257e4f4b58be6738271d79fa67c505ea9f Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Fri, 15 Jul 2022 10:11:57 +0200
Subject: [PATCH 73/74] try something

---
 tests/unittests/helpers/testers.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py
index 288a39e4d2a..cc014c56e4a 100644
--- a/tests/unittests/helpers/testers.py
+++ b/tests/unittests/helpers/testers.py
@@ -628,7 +628,12 @@ def compute(self):
 
 
 def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor:
+    """Utility function for randomly injecting the ignore index value into a tensor."""
+    if any(x.flatten() == ignore_index):  # the ignore index is already a class label, so leave the tensor as-is
+        return x
     idx = torch.randperm(x.numel())
     x = deepcopy(x)
-    x.view(-1)[idx[::5]] = ignore_index
+    # randomly pick a stride from {3, 4, 5} and set every ``skip``-th element to the ignore index value
+    skip = torch.randint(3, 6, (1,)).item()
+    x.view(-1)[idx[::skip]] = ignore_index
     return x

From 2f02092bb66bdf938f33e2a23cdc7d955a6e7ce5 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Fri, 15 Jul 2022 11:42:43 +0200
Subject: [PATCH 74/74] Apply suggestions from code review

---
 src/torchmetrics/classification/cohen_kappa.py             | 2 ++
 src/torchmetrics/classification/matthews_corrcoef.py       | 3 +++
 src/torchmetrics/functional/classification/cohen_kappa.py  | 1 +
 src/torchmetrics/functional/classification/jaccard.py      | 1 +
 .../functional/classification/matthews_corrcoef.py         | 3 +++
 5 files changed, 10 insertions(+)

diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py
index 9a0ba35fd4f..734b7840a6c 100644
--- a/src/torchmetrics/classification/cohen_kappa.py
+++ b/src/torchmetrics/classification/cohen_kappa.py
@@ -58,6 +58,7 @@ class labels.
     - ``None`` or ``'none'``: no weighting
     - ``'linear'``: linear weighting
     - ``'quadratic'``: quadratic weighting
+
     validate_args: bool indicating if input arguments and tensors should be validated for correctness.
         Set to ``False`` for faster computations.
     kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -131,6 +132,7 @@ class labels.
- ``None`` or ``'none'``: no weighting - ``'linear'``: linear weighting - ``'quadratic'``: quadratic weighting + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 4b16ed17d02..50c8ed6872b 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -48,6 +48,7 @@ class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -110,6 +111,7 @@ class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -178,6 +180,7 @@ class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/src/torchmetrics/functional/classification/cohen_kappa.py b/src/torchmetrics/functional/classification/cohen_kappa.py index 10ff193fdf4..c72ab4a696b 100644 --- a/src/torchmetrics/functional/classification/cohen_kappa.py +++ b/src/torchmetrics/functional/classification/cohen_kappa.py @@ -197,6 +197,7 @@ class labels. - ``None`` or ``'none'``: no weighting - ``'linear'``: linear weighting - ``'quadratic'``: quadratic weighting + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/src/torchmetrics/functional/classification/jaccard.py b/src/torchmetrics/functional/classification/jaccard.py index 481b1441353..159af3a92c8 100644 --- a/src/torchmetrics/functional/classification/jaccard.py +++ b/src/torchmetrics/functional/classification/jaccard.py @@ -116,6 +116,7 @@ def binary_jaccard_index( - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
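To make the patched reduction concrete: below is a minimal sketch, assuming only ``torch``, of what the ``average == "binary"`` branch of `_jaccard_index_reduce` computes; the cast mirrors the ``confmat = confmat.float()`` fix from PATCH 72, and the helper name `binary_jaccard_from_confmat` is illustrative, not part of the library.

import torch

def binary_jaccard_from_confmat(confmat: torch.Tensor) -> torch.Tensor:
    # Cast first (the PATCH 72 fix) so the score is floating point for any input dtype.
    confmat = confmat.float()
    # Binary confusion matrix layout is [[tn, fp], [fn, tp]], so J = tp / (tp + fp + fn).
    return confmat[1, 1] / (confmat[0, 1] + confmat[1, 0] + confmat[1, 1])

# Same data as the BinaryJaccardIndex doctest: 1 tp, 1 fn, 2 tn -> 1 / (1 + 0 + 1).
print(binary_jaccard_from_confmat(torch.tensor([[2, 0], [1, 1]])))  # tensor(0.5000)

Keeping the reduction separate from the confusion-matrix accumulation is what lets the three task-specific Jaccard metrics above share a single implementation.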
diff --git a/src/torchmetrics/functional/classification/matthews_corrcoef.py b/src/torchmetrics/functional/classification/matthews_corrcoef.py index 130e0e91a5e..5f50f7fcf1d 100644 --- a/src/torchmetrics/functional/classification/matthews_corrcoef.py +++ b/src/torchmetrics/functional/classification/matthews_corrcoef.py @@ -84,6 +84,7 @@ def binary_matthews_corrcoef( - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -140,6 +141,7 @@ def multiclass_matthews_corrcoef( - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -203,6 +205,7 @@ def multilabel_matthews_corrcoef( - ``'true'``: normalization over the targets (most commonly used) - ``'pred'``: normalization over the predictions - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
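For context on the ``denom`` refactor in PATCH 70: a minimal sketch, assuming only ``torch``, of the Matthews correlation coefficient computed from a confusion matrix with the zero-denominator guard; ``safe_mcc`` is a hypothetical standalone name, not the library's API. When all predictions or all targets fall into a single class, one of the covariance terms vanishes and the guard returns 0 instead of dividing by zero.

import torch

def safe_mcc(confmat: torch.Tensor) -> torch.Tensor:
    confmat = confmat.float()
    tk = confmat.sum(dim=1)   # per-class counts in the targets
    pk = confmat.sum(dim=0)   # per-class counts in the predictions
    c = torch.trace(confmat)  # correctly classified samples
    s = confmat.sum()         # total number of samples
    cov_ytyp = c * s - (tk * pk).sum()
    cov_ypyp = s**2 - (pk * pk).sum()
    cov_ytyt = s**2 - (tk * tk).sum()
    denom = cov_ypyp * cov_ytyt
    if denom == 0:  # degenerate confusion matrix: correlation is undefined, return 0
        return torch.tensor(0.0, dtype=confmat.dtype, device=confmat.device)
    return cov_ytyp / torch.sqrt(denom)

print(safe_mcc(torch.tensor([[2, 0], [0, 2]])))  # perfect predictions: tensor(1.)
print(safe_mcc(torch.tensor([[2, 0], [2, 0]])))  # single predicted class: tensor(0.)

Computing ``denom`` once, as the patch does, guarantees the guard and the square root operate on exactly the same quantity and avoids evaluating the product twice.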