diff --git a/CHANGELOG.md b/CHANGELOG.md index edc319511b195..2a46f49211268 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- `Accuracy` metric now generalizes to Top-k accuracy for (multi-dimensional) multi-class inputs using the `top_k` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +- `Accuracy` metric now enables the computation of subset accuracy for multi-label or multi-dimensional multi-class inputs with the `subset_accuracy` parameter ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) + +- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) ### Changed diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 387cbc3bd7482..5396403402072 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -292,6 +292,12 @@ FBeta .. autoclass:: pytorch_lightning.metrics.classification.FBeta :noindex: +Hamming Distance +~~~~~~~~~~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.classification.HammingDistance + :noindex: + Precision ~~~~~~~~~ @@ -323,10 +329,9 @@ Functional Metrics (Classification) accuracy [func] ~~~~~~~~~~~~~~~ -.. autofunction:: pytorch_lightning.metrics.functional.classification.accuracy +.. autofunction:: pytorch_lightning.metrics.functional.accuracy :noindex: - auc [func] ~~~~~~~~~~ @@ -382,6 +387,11 @@ fbeta [func] .. autofunction:: pytorch_lightning.metrics.functional.fbeta :noindex: +hamming_distance [func] +~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.hamming_distance + :noindex: iou [func] ~~~~~~~~~~ diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 2f28ad9abb0b7..c792fc5e71b03 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -15,6 +15,7 @@ from pytorch_lightning.metrics.classification import ( # noqa: F401 Accuracy, + HammingDistance, Precision, Recall, ConfusionMatrix, diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 79579035f6726..78163a9673887 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -15,6 +15,7 @@ from pytorch_lightning.metrics.classification.average_precision import AveragePrecision # noqa: F401 from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 from pytorch_lightning.metrics.classification.f_beta import FBeta, Fbeta, F1 # noqa: F401 +from pytorch_lightning.metrics.classification.hamming_distance import HammingDistance # noqa: F401 from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401 from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401 from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401 diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 330691a379574..e248c132026a4 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -16,35 +16,57 @@ import torch from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.utils import _input_format_classification 
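As a quick illustration of the two new `Accuracy` parameters described in the changelog entries above, here is a minimal usage sketch. It mirrors the `top_k=2` doctest added further down in this diff; the multi-label tensors are borrowed from the `HammingDistance` example below, and the expected subset-accuracy values are worked out by hand from the semantics documented in the new docstrings, so treat them as illustrative rather than authoritative.

```python
import torch
from pytorch_lightning.metrics import Accuracy

# Top-2 accuracy on multi-class probability predictions: a sample counts as
# correct if the target class is among the two highest-probability classes.
target = torch.tensor([0, 1, 2])
preds = torch.tensor([[0.1, 0.9, 0.0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]])
top2 = Accuracy(top_k=2)
print(top2(preds, target))  # tensor(0.6667) -- 2 of 3 targets fall in the top 2

# Subset accuracy on multi-label predictions: every label of a sample must be
# correct for the sample to count (values chosen here for illustration).
target = torch.tensor([[0, 1], [1, 1]])
preds = torch.tensor([[0, 1], [0, 1]])
print(Accuracy(subset_accuracy=True)(preds, target))   # tensor(0.5000) -- 1 of 2 samples fully correct
print(Accuracy(subset_accuracy=False)(preds, target))  # tensor(0.7500) -- 3 of 4 labels correct
```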
+from pytorch_lightning.metrics.functional.accuracy import _accuracy_update, _accuracy_compute class Accuracy(Metric): r""" Computes `Accuracy `_: - .. math:: \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y_i}) + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a - tensor of predictions. Works with binary, multiclass, and multilabel - data. Accepts logits from a model output or integer class values in - prediction. Works with multi-dimensional preds and target. + tensor of predictions. - Forward accepts + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting ``subset_accuracy=True``. - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. - This is the case for binary and multi-label logits. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + Accepts all input types listed in :ref:`metrics:Input types`. Args: threshold: - Threshold value for binary or multi-label logits. default: 0.5 + Threshold probability value for transforming probability predictions to binary + `(0,1)` predictions, in the case of binary or multi-label inputs. + top_k: + Number of highest probability predictions considered to find the correct label, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). + + For multi-label inputs, if the parameter is set to `True`, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to `False`, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). + + For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all + sub-samples (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to `False`, then all sub-samples are counted separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. default: True + Forward only calls ``update()`` and return None if this is set to False. dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step.
default: False @@ -63,10 +85,19 @@ class Accuracy(Metric): >>> accuracy(preds, target) tensor(0.5000) + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy = Accuracy(top_k=2) + >>> accuracy(preds, target) + tensor(0.6667) + """ + def __init__( self, threshold: float = 0.5, + top_k: Optional[int] = None, + subset_accuracy: bool = False, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, @@ -82,24 +113,35 @@ def __init__( self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + if not 0 <= threshold <= 1: + raise ValueError("The `threshold` should lie in the [0,1] interval.") + + if top_k is not None and top_k <= 0: + raise ValueError("The `top_k` should be an integer larger than 1.") + self.threshold = threshold + self.top_k = top_k + self.subset_accuracy = subset_accuracy def update(self, preds: torch.Tensor, target: torch.Tensor): """ - Update state with predictions and targets. + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. Args: - preds: Predictions from model - target: Ground truth values + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels """ - preds, target = _input_format_classification(preds, target, self.threshold) - assert preds.shape == target.shape - self.correct += torch.sum(preds == target) - self.total += target.numel() + correct, total = _accuracy_update( + preds, target, threshold=self.threshold, top_k=self.top_k, subset_accuracy=self.subset_accuracy + ) + + self.correct += correct + self.total += total - def compute(self): + def compute(self) -> torch.Tensor: """ - Computes accuracy over state. + Computes accuracy based on inputs passed in to ``update`` previously. """ - return self.correct.float() / self.total + return _accuracy_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py new file mode 100644 index 0000000000000..b3281cd60987c --- /dev/null +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -0,0 +1,105 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_update, _hamming_distance_compute + + +class HammingDistance(Metric): + r""" + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: + + .. 
math:: + \text{Hamming distance} = \frac{1}{N \cdot L}\sum_i^N \sum_l^L 1(y_{il} \neq \hat{y_{il}}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + threshold: + Threshold probability value for transforming probability predictions to binary + `(0,1)` predictions, in the case of binary or multi-label inputs. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When ``None``, DDP + will be used to perform the all gather. + + Example: + + >>> from pytorch_lightning.metrics import HammingDistance + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_distance = HammingDistance() + >>> hamming_distance(preds, target) + tensor(0.2500) + + """ + + def __init__( + self, + threshold: float = 0.5, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + if not 0 <= threshold <= 1: + raise ValueError("The `threshold` should lie in the [0,1] interval.") + self.threshold = threshold + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. See :ref:`metrics:Input types` for more information + on input types. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels + """ + correct, total = _hamming_distance_update(preds, target, self.threshold) + + self.correct += correct + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes hamming distance based on inputs passed in to ``update`` previously. 
+ """ + return _hamming_distance_compute(self.correct, self.total) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index f52fb743b6979..bac9be59b1c9f 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -405,6 +405,11 @@ def _input_format_classification( else: preds, target = preds.squeeze(), target.squeeze() + # Convert half precision tensors to full precision, as not all ops are supported + # for example, min() is not supported + if preds.dtype == torch.float16: + preds = preds.float() + case = _check_classification_inputs( preds, target, diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 12812b75f8663..1b28f534f80e7 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. from pytorch_lightning.metrics.functional.average_precision import average_precision # noqa: F401 from pytorch_lightning.metrics.functional.classification import ( # noqa: F401 - accuracy, auc, auroc, dice_score, @@ -32,8 +31,10 @@ ) from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401 # TODO: unify metrics between class and functional, add below +from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401 from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401 from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401 +from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_error import mean_squared_error # noqa: F401 from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py new file mode 100644 index 0000000000000..8ba0e49b881b8 --- /dev/null +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -0,0 +1,120 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Tuple, Optional + +import torch +from pytorch_lightning.metrics.classification.helpers import _input_format_classification + + +def _accuracy_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float, top_k: Optional[int], subset_accuracy: bool +) -> Tuple[torch.Tensor, torch.Tensor]: + + preds, target, mode = _input_format_classification(preds, target, threshold=threshold, top_k=top_k) + + if mode == "binary" or (mode == "multi-label" and subset_accuracy): + correct = (preds == target).all(dim=1).sum() + total = torch.tensor(target.shape[0], device=target.device) + elif mode == "multi-label" and not subset_accuracy: + correct = (preds == target).sum() + total = torch.tensor(target.numel(), device=target.device) + elif mode == "multi-class" or (mode == "multi-dim multi-class" and not subset_accuracy): + correct = (preds * target).sum() + total = target.sum() + elif mode == "multi-dim multi-class" and subset_accuracy: + sample_correct = (preds * target).sum(dim=(1, 2)) + correct = (sample_correct == target.shape[2]).sum() + total = torch.tensor(target.shape[0], device=target.device) + + return correct, total + + +def _accuracy_compute(correct: torch.Tensor, total: torch.Tensor) -> torch.Tensor: + return correct.float() / total + + +def accuracy( + preds: torch.Tensor, + target: torch.Tensor, + threshold: float = 0.5, + top_k: Optional[int] = None, + subset_accuracy: bool = False, +) -> torch.Tensor: + r""" + Computes `Accuracy `_: + + .. math:: + \text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i) + + Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a + tensor of predictions. + + For multi-class and multi-dimensional multi-class data with probability predictions, the + parameter ``top_k`` generalizes this metric to a Top-K accuracy metric: for each sample the + top-K highest probability items are considered to find the correct label. + + For multi-label and multi-dimensional multi-class inputs, this metric computes the "global" + accuracy by default, which counts all labels or sub-samples separately. This can be + changed to subset accuracy (which requires all labels or sub-samples in the sample to + be correctly predicted) by setting ``subset_accuracy=True``. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model (probabilities, or labels) + target: Ground truth labels + threshold: + Threshold probability value for transforming probability predictions to binary + `(0,1)` predictions, in the case of binary or multi-label inputs. + top_k: + Number of highest probability predictions considered to find the correct label, relevant + only for (multi-dimensional) multi-class inputs with probability predictions. The + default value (``None``) will be interpreted as 1 for these inputs. + + Should be left at default (``None``) for all other types of inputs. + subset_accuracy: + Whether to compute subset accuracy for multi-label and multi-dimensional + multi-class inputs (has no effect for other input types). + + For multi-label inputs, if the parameter is set to `True`, then all labels for + each sample must be correctly predicted for the sample to count as correct. If it + is set to `False`, then all labels are counted separately - this is equivalent to + flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``). 
+ + For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all + sub-samples (on the extra axis) must be correct for the sample to be counted as correct. + If it is set to `False`, then all sub-samples are counted separately - this is equivalent, + in the case of label predictions, to flattening the inputs beforehand (i.e. + ``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter + still applies in both cases, if set. + + Example: + + >>> from pytorch_lightning.metrics.functional import accuracy + >>> target = torch.tensor([0, 1, 2, 3]) + >>> preds = torch.tensor([0, 2, 1, 3]) + >>> accuracy(preds, target) + tensor(0.5000) + + >>> target = torch.tensor([0, 1, 2]) + >>> preds = torch.tensor([[0.1, 0.9, 0], [0.3, 0.1, 0.6], [0.2, 0.5, 0.3]]) + >>> accuracy(preds, target, top_k=2) + tensor(0.6667) + """ + + if top_k is not None and top_k <= 0: + raise ValueError("The `top_k` should be an integer larger than 1.") + + correct, total = _accuracy_update(preds, target, threshold, top_k, subset_accuracy) + return _accuracy_compute(correct, total) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 140dff7159da8..f4e66217c3466 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -229,47 +229,6 @@ def stat_scores_multiple_classes( return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() -def accuracy( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, - class_reduction: str = 'micro', - return_state: bool = False -) -> torch.Tensor: - """ - Computes the accuracy classification score - - Args: - pred: predicted labels - target: ground truth labels - num_classes: number of classes - class_reduction: method to reduce metric score over labels - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'``: returns calculated metric per class - return_state: returns a internal state that can be ddp reduced - before doing the final calculation - - Return: - A Tensor with the accuracy score. - - Example: - - >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> accuracy(x, y) - tensor(0.7500) - - """ - tps, fps, tns, fns, sups = stat_scores_multiple_classes( - pred=pred, target=target, num_classes=num_classes) - if return_state: - return {'tps': tps, 'sups': sups} - return class_reduce(tps, sups, sups, class_reduction=class_reduction) - - def _confmat_normalize(cm): """ Normalization function for confusion matrix """ cm = cm / cm.sum(-1, keepdim=True) diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py new file mode 100644 index 0000000000000..7d8ecafd08b00 --- /dev/null +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -0,0 +1,71 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple, Union + +import torch +from pytorch_lightning.metrics.classification.helpers import _input_format_classification + + +def _hamming_distance_update( + preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5 +) -> Tuple[torch.Tensor, int]: + preds, target, _ = _input_format_classification(preds, target, threshold=threshold) + + correct = (preds == target).sum() + total = preds.numel() + + return correct, total + + +def _hamming_distance_compute(correct: torch.Tensor, total: Union[int, torch.Tensor]) -> torch.Tensor: + return 1 - correct.float() / total + + +def hamming_distance(preds: torch.Tensor, target: torch.Tensor, threshold: float = 0.5) -> torch.Tensor: + r""" + Computes the average `Hamming distance `_ (also + known as Hamming loss) between targets and predictions: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + This is the same as ``1-accuracy`` for binary data, while for all other types of inputs it + treats each possible label separately - meaning that, for example, multi-class data is + treated as if it were multi-label. + + Accepts all input types listed in :ref:`metrics:Input types`. + + Args: + preds: Predictions from model + target: Ground truth + threshold: + Threshold probability value for transforming probability predictions to binary + (0,1) predictions, in the case of binary or multi-label inputs. 
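For a concrete sense of what `_hamming_distance_update` and `_hamming_distance_compute` return together, here is the arithmetic behind the doctest used in the example below, written out by hand on the same tensors (a sketch, not part of the patched module):

```python
import torch
from pytorch_lightning.metrics.functional import hamming_distance

target = torch.tensor([[0, 1], [1, 1]])
preds = torch.tensor([[0, 1], [0, 1]])

# 3 of the 4 individual labels match, so the Hamming distance is 1 - 3/4 = 0.25.
correct = (preds == target).sum()       # tensor(3)
total = preds.numel()                   # 4
print(1 - correct.float() / total)      # tensor(0.2500)
print(hamming_distance(preds, target))  # tensor(0.2500)
```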
+ + Example: + + >>> from pytorch_lightning.metrics.functional import hamming_distance + >>> target = torch.tensor([[0, 1], [1, 1]]) + >>> preds = torch.tensor([[0, 1], [0, 1]]) + >>> hamming_distance(preds, target) + tensor(0.2500) + + """ + + correct, total = _hamming_distance_update(preds, target, threshold) + return _hamming_distance_compute(correct, total) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 92faca200d0aa..8c59bb4991cab 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -169,8 +169,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch [1, 1, 0]], dtype=torch.int32) """ zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(1, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) - + topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) return topk_tensor.int() diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 017438269bdbf..7b28e07c894dd 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -1,9 +1,13 @@ +from functools import partial + import numpy as np import pytest import torch -from sklearn.metrics import accuracy_score +from sklearn.metrics import accuracy_score as sk_accuracy -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.metrics.functional import accuracy +from pytorch_lightning.metrics.classification.helpers import _input_format_classification from tests.metrics.classification.inputs import ( _binary_inputs, _binary_prob_inputs, @@ -13,101 +17,155 @@ _multidim_multiclass_prob_inputs, _multilabel_inputs, _multilabel_prob_inputs, + _multilabel_multidim_prob_inputs, + _multilabel_multidim_inputs, ) from tests.metrics.utils import THRESHOLD, MetricTester torch.manual_seed(42) -def _sk_accuracy_binary_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +def _sk_accuracy(preds, target, subset_accuracy): + sk_preds, sk_target, mode = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + if mode == "multi-dim multi-class" and not subset_accuracy: + sk_preds, sk_target = np.transpose(sk_preds, (0, 2, 1)), np.transpose(sk_target, (0, 2, 1)) + sk_preds, sk_target = sk_preds.reshape(-1, sk_preds.shape[2]), sk_target.reshape(-1, sk_target.shape[2]) + elif mode == mode == "multi-dim multi-class" and subset_accuracy: + return np.all(sk_preds == sk_target, axis=(1, 2)).mean() + elif mode == "multi-label" and not subset_accuracy: + sk_preds, sk_target = sk_preds.reshape(-1), sk_target.reshape(-1) -def _sk_accuracy_binary(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + return sk_accuracy(y_true=sk_target, y_pred=sk_preds) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +@pytest.mark.parametrize( + "preds, target, subset_accuracy", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, False), + (_binary_inputs.preds, _binary_inputs.target, False), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, True), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, False), + (_multilabel_inputs.preds, _multilabel_inputs.target, True), 
+ (_multilabel_inputs.preds, _multilabel_inputs.target, False), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, False), + (_multiclass_inputs.preds, _multiclass_inputs.target, False), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, False), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target, True), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, False), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target, True), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, True), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target, False), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, True), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target, False), + ], +) +class TestAccuracies(MetricTester): + @pytest.mark.parametrize("ddp", [False, True]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_accuracy_class(self, ddp, dist_sync_on_step, preds, target, subset_accuracy): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=Accuracy, + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, + ) -def _sk_accuracy_multilabel_prob(preds, target): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() + def test_accuracy_fn(self, preds, target, subset_accuracy): + self.run_functional_metric_test( + preds, + target, + metric_functional=accuracy, + sk_metric=partial(_sk_accuracy, subset_accuracy=subset_accuracy), + metric_args={"threshold": THRESHOLD, "subset_accuracy": subset_accuracy}, + ) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +_l1to4 = [0.1, 0.2, 0.3, 0.4] +_l1to4t3 = np.array([_l1to4, _l1to4, _l1to4]) +_l1to4t3_mc = [_l1to4t3.T, _l1to4t3.T, _l1to4t3.T] -def _sk_accuracy_multilabel(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +# The preds in these examples always put highest probability on class 3, second highest on class 2, +# third highest on class 1, and lowest on class 0 +_topk_preds_mc = torch.tensor([_l1to4t3, _l1to4t3]).float() +_topk_target_mc = torch.tensor([[1, 2, 3], [2, 1, 0]]) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# This is like for MC case, but one sample in each batch is sabotaged with 0 class prediction :) +_topk_preds_mdmc = torch.tensor([_l1to4t3_mc, _l1to4t3_mc]).float() +_topk_target_mdmc = torch.tensor([[[1, 1, 0], [2, 2, 2], [3, 3, 3]], [[2, 2, 0], [1, 1, 1], [0, 0, 0]]]) -def _sk_accuracy_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() +# Replace with a proper sk_metric test once sklearn 0.24 hits :) +@pytest.mark.parametrize( + "preds, target, exp_result, k, subset_accuracy", + [ + (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, False), + (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, False), + (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, False), + (_topk_preds_mc, _topk_target_mc, 1 / 6, 1, True), + (_topk_preds_mc, _topk_target_mc, 3 / 6, 2, True), + (_topk_preds_mc, _topk_target_mc, 5 / 6, 3, True), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, False), + (_topk_preds_mdmc, _topk_target_mdmc, 8 / 18, 2, False), + (_topk_preds_mdmc, 
_topk_target_mdmc, 13 / 18, 3, False), + (_topk_preds_mdmc, _topk_target_mdmc, 1 / 6, 1, True), + (_topk_preds_mdmc, _topk_target_mdmc, 2 / 6, 2, True), + (_topk_preds_mdmc, _topk_target_mdmc, 3 / 6, 3, True), + ], +) +def test_topk_accuracy(preds, target, exp_result, k, subset_accuracy): + topk = Accuracy(top_k=k, subset_accuracy=subset_accuracy) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + for batch in range(preds.shape[0]): + topk(preds[batch], target[batch]) + assert topk.compute() == exp_result -def _sk_accuracy_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + # Test functional + total_samples = target.shape[0] * target.shape[1] - return accuracy_score(y_true=sk_target, y_pred=sk_preds) + preds = preds.view(total_samples, 4, -1) + target = target.view(total_samples, -1) + assert accuracy(preds, target, top_k=k, subset_accuracy=subset_accuracy) == exp_result -def _sk_accuracy_multidim_multiclass_prob(preds, target): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +# Only MC and MDMC with probs input type should be accepted for top_k +@pytest.mark.parametrize( + "preds, target", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + ], +) +def test_topk_accuracy_wrong_input_types(preds, target): + topk = Accuracy(top_k=1) + with pytest.raises(ValueError): + topk(preds[0], target[0]) -def _sk_accuracy_multidim_multiclass(preds, target): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + with pytest.raises(ValueError): + accuracy(preds[0], target[0], top_k=1) - return accuracy_score(y_true=sk_target, y_pred=sk_preds) +@pytest.mark.parametrize("top_k, threshold", [(0, 0.5), (None, 1.5)]) +def test_wrong_params(top_k, threshold): + preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target -def test_accuracy_invalid_shape(): with pytest.raises(ValueError): - acc = Accuracy() - acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) + acc = Accuracy(threshold=threshold, top_k=top_k) + acc(preds, target) + acc.compute() - -@pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("dist_sync_on_step", [True, False]) -@pytest.mark.parametrize( - "preds, target, sk_metric", - [ - (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_accuracy_binary_prob), - (_binary_inputs.preds, _binary_inputs.target, _sk_accuracy_binary), - (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_accuracy_multilabel_prob), - (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_accuracy_multilabel), - (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_accuracy_multiclass_prob), - (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_accuracy_multiclass), - ( - _multidim_multiclass_prob_inputs.preds, - _multidim_multiclass_prob_inputs.target, - _sk_accuracy_multidim_multiclass_prob, - ), - (_multidim_multiclass_inputs.preds, 
_multidim_multiclass_inputs.target, _sk_accuracy_multidim_multiclass), - ], -) -class TestAccuracy(MetricTester): - def test_accuracy(self, ddp, dist_sync_on_step, preds, target, sk_metric): - self.run_class_metric_test( - ddp=ddp, - preds=preds, - target=target, - metric_class=Accuracy, - sk_metric=sk_metric, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, - ) + with pytest.raises(ValueError): + accuracy(preds, target, threshold=threshold, top_k=top_k) diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py new file mode 100644 index 0000000000000..73c2abe771fae --- /dev/null +++ b/tests/metrics/classification/test_hamming_distance.py @@ -0,0 +1,82 @@ +import pytest +import torch +from sklearn.metrics import hamming_loss as sk_hamming_loss + +from pytorch_lightning.metrics import HammingDistance +from pytorch_lightning.metrics.functional import hamming_distance +from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from tests.metrics.classification.inputs import ( + _binary_inputs, + _binary_prob_inputs, + _multiclass_inputs, + _multiclass_prob_inputs, + _multidim_multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multilabel_inputs, + _multilabel_prob_inputs, + _multilabel_multidim_prob_inputs, + _multilabel_multidim_inputs, +) +from tests.metrics.utils import THRESHOLD, MetricTester + +torch.manual_seed(42) + + +def _sk_hamming_loss(preds, target): + sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) + sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() + sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) + + return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) + + +@pytest.mark.parametrize( + "preds, target", + [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target), + (_binary_inputs.preds, _binary_inputs.target), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target), + (_multilabel_inputs.preds, _multilabel_inputs.target), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target), + (_multiclass_inputs.preds, _multiclass_inputs.target), + (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target), + (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target), + (_multilabel_multidim_prob_inputs.preds, _multilabel_multidim_prob_inputs.target), + (_multilabel_multidim_inputs.preds, _multilabel_multidim_inputs.target), + ], +) +class TestHammingDistance(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [False, True]) + def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=HammingDistance, + sk_metric=_sk_hamming_loss, + dist_sync_on_step=dist_sync_on_step, + metric_args={"threshold": THRESHOLD}, + ) + + def test_hamming_distance_fn(self, preds, target): + self.run_functional_metric_test( + preds, + target, + metric_functional=hamming_distance, + sk_metric=_sk_hamming_loss, + metric_args={"threshold": THRESHOLD}, + ) + + +@pytest.mark.parametrize("threshold", [1.5]) +def test_wrong_params(threshold): + preds, target = _multiclass_prob_inputs.preds, _multiclass_prob_inputs.target + + with pytest.raises(ValueError): + ham_dist = HammingDistance(threshold=threshold) + ham_dist(preds, target) + ham_dist.compute() + + with 
pytest.raises(ValueError): + hamming_distance(preds, target, threshold=threshold) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index c4d01d282fa57..30d3f06707301 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -22,6 +22,8 @@ torch.manual_seed(42) # Some additional inputs to test on +_ml_prob_half = Input(_ml_prob.preds.half(), _ml_prob.target) + _mc_prob_2cls_preds = rand(NUM_BATCHES, BATCH_SIZE, 2) _mc_prob_2cls_preds /= _mc_prob_2cls_preds.sum(dim=2, keepdim=True) _mc_prob_2cls = Input(_mc_prob_2cls_preds, randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))) @@ -133,6 +135,8 @@ def _mlmd_prob_to_mc_preds_tr(x): (_mdmc_prob_many_dims, None, None, 2, "multi-dim multi-class", _top2_rshp2, _onehot_rshp1), ########################### # Test some special cases + # Make sure that half precision works, i.e. is converted to full precision + (_ml_prob_half, None, None, None, "multi-label", lambda x: _ml_preds_tr(x.float()), _rshp1), # Binary as multiclass (_bin, None, None, None, "multi-class", _onehot2, _onehot2), # Binary probs as multiclass diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 4a2d690ec08a3..ec2c8bf387005 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -3,7 +3,6 @@ import pytest import torch from sklearn.metrics import ( - accuracy_score as sk_accuracy, jaccard_score as sk_jaccard_score, precision_score as sk_precision, recall_score as sk_recall, @@ -14,7 +13,6 @@ from pytorch_lightning.metrics.functional.classification import ( stat_scores, stat_scores_multiple_classes, - accuracy, precision, recall, dice_score, @@ -28,7 +26,6 @@ @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ - pytest.param(sk_accuracy, accuracy, False, id='accuracy'), pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), @@ -164,39 +161,6 @@ def test_stat_scores_multiclass(pred, target, reduction, assert torch.allclose(torch.tensor(expected_support).to(sup), sup) -def test_multilabel_accuracy(): - # Dense label indicator matrix format - y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) - y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) - - assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.])) - assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.])) - assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.])) - - # num_classes does not match extracted number from input we expect a warning - with pytest.warns(RuntimeWarning, - match=r'You have set .* number of classes which is' - r' different from predicted (.*) and' - r' target (.*) number of classes'): - _ = accuracy(y2, torch.zeros_like(y2), num_classes=3) - - -def test_accuracy(): - pred = torch.tensor([0, 1, 2, 3]) - target = torch.tensor([0, 1, 2, 2]) - acc = accuracy(pred, target) - - assert acc.item() == 0.75 - - pred = torch.tensor([0, 1, 2, 2]) - target = torch.tensor([0, 1, 1, 3]) - acc = 
accuracy(pred, target) - - assert acc.item() == 0.50 - - @pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) diff --git a/tests/metrics/regression/test_ssim.py b/tests/metrics/regression/test_ssim.py index f581188e89fce..8bb304850e3f2 100644 --- a/tests/metrics/regression/test_ssim.py +++ b/tests/metrics/regression/test_ssim.py @@ -53,9 +53,7 @@ def _sk_metric(preds, target, data_range, multichannel): class TestSSIM(MetricTester): atol = 6e-5 - # TODO: for some reason this test hangs with ddp=True - # @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("ddp", [False]) + @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) def test_ssim(self, preds, target, multichannel, ddp, dist_sync_on_step): self.run_class_metric_test( diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index c607a466b2068..4bd6608ce3fcf 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -11,6 +11,11 @@ from pytorch_lightning.metrics import Metric +try: + set_start_method("spawn") +except RuntimeError: + pass + NUM_PROCESSES = 2 NUM_BATCHES = 10 BATCH_SIZE = 32 @@ -165,10 +170,7 @@ def setup_class(self): """Setup the metric class. This will spawn the pool of workers that are used for metric testing and setup_ddp """ - try: - set_start_method("spawn") - except RuntimeError: - pass + self.poolSize = NUM_PROCESSES self.pool = Pool(processes=self.poolSize) self.pool.starmap(setup_ddp, [(rank, self.poolSize) for rank in range(self.poolSize)])
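Finally, the one-line `select_topk` fix in `pytorch_lightning/metrics/utils.py` above replaces the hard-coded scatter dimension with the `dim` argument that `topk` already used. A small self-contained sketch of the patched behaviour, reusing the same logic (the example tensor is an assumption for illustration):

```python
import torch

def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor:
    # Same logic as the patched utility: mark the top-k entries along `dim` with 1.
    zeros = torch.zeros_like(prob_tensor)
    topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0)
    return topk_tensor.int()

probs = torch.tensor([[0.1, 0.6, 0.3], [0.7, 0.2, 0.1]])
print(select_topk(probs, topk=2))
# tensor([[0, 1, 1],
#         [1, 1, 0]], dtype=torch.int32)
```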