diff --git a/.github/actions/unittesting/action.yml b/.github/actions/unittesting/action.yml index e69fce887b9..cf45d95e5e6 100644 --- a/.github/actions/unittesting/action.yml +++ b/.github/actions/unittesting/action.yml @@ -57,7 +57,7 @@ runs: - name: Unittests working-directory: ./tests - run: python -m pytest ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }} + run: python -m pytest -v --maxfail=5 ${{ inputs.dirs }} --cov=torchmetrics --junitxml="$PYTEST_ARTEFACT.xml" --durations=50 ${{ inputs.test-timeout }} shell: ${{ inputs.shell-type }} - name: Upload pytest results diff --git a/CHANGELOG.md b/CHANGELOG.md index 3cac9ed61cf..99c25a3ea51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Classification refactor ( + [#1054](https://github.com/Lightning-AI/metrics/pull/1054), +) + - Changed update in `FID` metric to be done in a online fashion to save memory ([#1199](https://github.com/PyTorchLightning/metrics/pull/1199)) diff --git a/docs/source/classification/confusion_matrix.rst b/docs/source/classification/confusion_matrix.rst index a1cc43fdfe9..bde2207e043 100644 --- a/docs/source/classification/confusion_matrix.rst +++ b/docs/source/classification/confusion_matrix.rst @@ -12,11 +12,56 @@ Confusion Matrix Module Interface ________________ +ConfusionMatrix +^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.ConfusionMatrix :noindex: +BinaryConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryConfusionMatrix + :noindex: + :exclude-members: update, compute + +MulticlassConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassConfusionMatrix + :noindex: + :exclude-members: update, compute + +MultilabelConfusionMatrix +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelConfusionMatrix + :noindex: + :exclude-members: update, compute + Functional Interface ____________________ +confusion_matrix +^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.confusion_matrix :noindex: + +binary_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_confusion_matrix + :noindex: + +multiclass_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_confusion_matrix + :noindex: + +multilabel_confusion_matrix +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_confusion_matrix + :noindex: diff --git a/docs/source/classification/stat_scores.rst b/docs/source/classification/stat_scores.rst index 809c3106948..382e6048534 100644 --- a/docs/source/classification/stat_scores.rst +++ b/docs/source/classification/stat_scores.rst @@ -12,11 +12,56 @@ Stat Scores Module Interface ________________ +StatScores +^^^^^^^^^^ + .. autoclass:: torchmetrics.StatScores :noindex: +BinaryStatScores +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryStatScores + :noindex: + :exclude-members: update, compute + +MulticlassStatScores +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassStatScores + :noindex: + :exclude-members: update, compute + +MultilabelStatScores +^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelStatScores + :noindex: + :exclude-members: update, compute + Functional Interface ____________________ +stat_scores +^^^^^^^^^^^ + .. 
autofunction:: torchmetrics.functional.stat_scores :noindex: + +binary_stat_scores +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_stat_scores + :noindex: + +multiclass_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_stat_scores + :noindex: + +multilabel_stat_scores +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_stat_scores + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 483a4d786b7..1c9cff4c0ab 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -26,6 +26,8 @@ ROC, Accuracy, AveragePrecision, + BinaryConfusionMatrix, + BinaryStatScores, BinnedAveragePrecision, BinnedPrecisionRecallCurve, BinnedRecallAtFixedPrecision, @@ -43,6 +45,10 @@ LabelRankingAveragePrecision, LabelRankingLoss, MatthewsCorrCoef, + MulticlassConfusionMatrix, + MulticlassStatScores, + MultilabelConfusionMatrix, + MultilabelStatScores, Precision, PrecisionRecallCurve, Recall, @@ -126,6 +132,9 @@ "CHRFScore", "CohenKappa", "ConfusionMatrix", + "BinaryConfusionMatrix", + "MulticlassConfusionMatrix", + "MultilabelConfusionMatrix", "CosineSimilarity", "CoverageError", "Dice", @@ -187,6 +196,9 @@ "SQuAD", "StructuralSimilarityIndexMeasure", "StatScores", + "BinaryStatScores", + "MulticlassStatScores", + "MultilabelStatScores", "SumMetric", "SymmetricMeanAbsolutePercentageError", "TranslationEditRate", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 70ae4d5179c..2575b9f9ba2 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -20,7 +20,12 @@ from torchmetrics.classification.binned_precision_recall import BinnedRecallAtFixedPrecision # noqa: F401 from torchmetrics.classification.calibration_error import CalibrationError # noqa: F401 from torchmetrics.classification.cohen_kappa import CohenKappa # noqa: F401 -from torchmetrics.classification.confusion_matrix import ConfusionMatrix # noqa: F401 +from torchmetrics.classification.confusion_matrix import ( # noqa: F401 + BinaryConfusionMatrix, + ConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) from torchmetrics.classification.dice import Dice # noqa: F401 from torchmetrics.classification.f_beta import F1Score, FBetaScore # noqa: F401 from torchmetrics.classification.hamming import HammingDistance # noqa: F401 @@ -37,4 +42,9 @@ ) from torchmetrics.classification.roc import ROC # noqa: F401 from torchmetrics.classification.specificity import Specificity # noqa: F401 -from torchmetrics.classification.stat_scores import StatScores # noqa: F401 +from torchmetrics.classification.stat_scores import ( # noqa: F401 + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) diff --git a/src/torchmetrics/classification/confusion_matrix.py b/src/torchmetrics/classification/confusion_matrix.py index a847b04044b..8ee276723c7 100644 --- a/src/torchmetrics/classification/confusion_matrix.py +++ b/src/torchmetrics/classification/confusion_matrix.py @@ -15,11 +15,308 @@ import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update +from torchmetrics.functional.classification.confusion_matrix import ( + _binary_confusion_matrix_arg_validation, + _binary_confusion_matrix_compute, + _binary_confusion_matrix_format, + 
_binary_confusion_matrix_tensor_validation, + _binary_confusion_matrix_update, + _confusion_matrix_compute, + _confusion_matrix_update, + _multiclass_confusion_matrix_arg_validation, + _multiclass_confusion_matrix_compute, + _multiclass_confusion_matrix_format, + _multiclass_confusion_matrix_tensor_validation, + _multiclass_confusion_matrix_update, + _multilabel_confusion_matrix_arg_validation, + _multilabel_confusion_matrix_compute, + _multilabel_confusion_matrix_format, + _multilabel_confusion_matrix_tensor_validation, + _multilabel_confusion_matrix_update, +) from torchmetrics.metric import Metric +class BinaryConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics import BinaryConfusionMatrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> metric = BinaryConfusionMatrix() + >>> metric(preds, target) + tensor([[2, 0], + [1, 1]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _binary_confusion_matrix_tensor_validation(preds, target, self.ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, self.threshold, self.ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [2,2] matrix. + """ + return _binary_confusion_matrix_compute(self.confmat, self.normalize) + + +class MulticlassConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (pred is integer tensor): + >>> from torchmetrics import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics import MulticlassConfusionMatrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassConfusionMatrix(num_classes=3) + >>> metric(preds, target) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + self.num_classes = num_classes + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_classes, num_classes, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multiclass_confusion_matrix_tensor_validation(preds, target, self.num_classes, self.ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, self.ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, self.num_classes) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [num_classes, num_classes] matrix. + """ + return _multiclass_confusion_matrix_compute(self.confmat, self.normalize) + + +class MultilabelConfusionMatrix(Metric): + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + num_classes: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics import MultilabelConfusionMatrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelConfusionMatrix(num_labels=3) + >>> metric(preds, target) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["none", "true", "pred", "all"]] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + self.num_labels = num_labels + self.threshold = threshold + self.ignore_index = ignore_index + self.normalize = normalize + self.validate_args = validate_args + + self.add_state("confmat", torch.zeros(num_labels, 2, 2, dtype=torch.long), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multilabel_confusion_matrix_tensor_validation(preds, target, self.num_labels, self.ignore_index) + preds, target = _multilabel_confusion_matrix_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + confmat = _multilabel_confusion_matrix_update(preds, target, self.num_labels) + self.confmat += confmat + + def compute(self) -> Tensor: + """Computes confusion matrix. + + Returns an [num_labels,2,2] matrix. + """ + return _multilabel_confusion_matrix_compute(self.confmat, self.normalize) + + +# -------------------------- Old stuff -------------------------- + + class ConfusionMatrix(Metric): r"""Computes the `confusion matrix`_. diff --git a/src/torchmetrics/classification/stat_scores.py b/src/torchmetrics/classification/stat_scores.py index eca2150d63b..50bb6ed0837 100644 --- a/src/torchmetrics/classification/stat_scores.py +++ b/src/torchmetrics/classification/stat_scores.py @@ -11,16 +11,468 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple, Union import torch from torch import Tensor - -from torchmetrics.functional.classification.stat_scores import _stat_scores_compute, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_compute, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_compute, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_compute, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _stat_scores_compute, + _stat_scores_update, +) from torchmetrics.metric import Metric from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +class _AbstractStatScores(Metric): + # define common functions + def _create_state(self, size: int, multidim_average: str) -> None: + """Initialize the states for the different statistics.""" + default: Union[Callable[[], list], Callable[[], Tensor]] + if multidim_average == "samplewise": + default = lambda: [] + dist_reduce_fx = "cat" + else: + default = lambda: torch.zeros(size, dtype=torch.long) + dist_reduce_fx = "sum" + self.add_state("tp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fp", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("tn", default(), dist_reduce_fx=dist_reduce_fx) + self.add_state("fn", default(), dist_reduce_fx=dist_reduce_fx) + + def _update_state(self, tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor) -> None: + """Update states depending on multidim_average argument.""" + if self.multidim_average == "samplewise": + self.tp.append(tp) + self.fp.append(fp) + self.tn.append(tn) + self.fn.append(fn) + else: + self.tp += tp + self.fp += fp + self.tn += tn + self.fn += fn + + def _final_state(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Final aggregation in case of list states.""" + tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp + fp = torch.cat(self.fp) if 
isinstance(self.fp, list) else self.fp + tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn + fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn + return tp, fp, tn, fn + + +class BinaryStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (multidim tensors): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_stat_scores(preds, target, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + self.threshold = threshold + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(1, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. 
+ + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index) + preds, target = _binary_stat_scores_format(preds, target, self.threshold, self.ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, self.multidim_average) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on the ``multidim_average`` parameter: + + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average) + + +class MulticlassStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... 
[0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise') + tensor([[3, 3, 9, 3, 6], + [2, 4, 8, 4, 6]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[[2, 1, 3, 0, 2], + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + self.num_classes = num_classes + self.top_k = top_k + self.average = average + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(num_classes, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multiclass_stat_scores_tensor_validation( + preds, target, self.num_classes, self.multidim_average, self.ignore_index + ) + preds, target = _multiclass_stat_scores_format(preds, target, self.top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update( + preds, target, self.num_classes, self.top_k, self.multidim_average, self.ignore_index + ) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _multiclass_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) + + +class MultilabelStatScores(_AbstractStatScores): + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multilabel tasks. Related to `Type I and Type II errors`_. 
+ + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... 
) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + self.num_labels = num_labels + self.threshold = threshold + self.average = average + self.multidim_average = multidim_average + self.ignore_index = ignore_index + self.validate_args = validate_args + + self._create_state(num_labels, multidim_average) + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + """Update state with predictions and targets. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + """ + if self.validate_args: + _multilabel_stat_scores_tensor_validation( + preds, target, self.num_labels, self.multidim_average, self.ignore_index + ) + preds, target = _multilabel_stat_scores_format( + preds, target, self.num_labels, self.threshold, self.ignore_index + ) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, self.multidim_average) + self._update_state(tp, fp, tn, fn) + + def compute(self) -> Tensor: + """Computes the final statistics. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + """ + tp, fp, tn, fn = self._final_state() + return _multilabel_stat_scores_compute(tp, fp, tn, fn, self.average, self.multidim_average) + + +# -------------------------- Old stuff -------------------------- + + class StatScores(Metric): r"""Computes the number of true positives, false positives, true negatives, false negatives. Related to `Type I and Type II errors`_ and the `confusion matrix`_. 
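Illustrative usage of the new task-specific modular API added above (a minimal sketch only; the assumed behaviour, sigmoid plus 0.5 thresholding for float binary/multilabel preds and argmax over the class dimension for float multiclass preds, and the expected values are taken directly from the doctest examples in the new docstrings)::

    import torch
    from torchmetrics import BinaryStatScores, MulticlassConfusionMatrix

    binary_stats = BinaryStatScores()                  # returns [tp, fp, tn, fn, support]
    confmat = MulticlassConfusionMatrix(num_classes=3)

    # Binary: float preds in [0, 1] are thresholded at 0.5 (values outside [0, 1]
    # would be treated as logits and passed through a sigmoid first)
    preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92])
    target = torch.tensor([0, 1, 0, 1, 0, 1])
    binary_stats(preds, target)                        # tensor([2, 1, 2, 1, 3])

    # Multiclass: (N, C) float preds are converted with argmax along the class dimension
    preds = torch.tensor([[0.16, 0.26, 0.58],
                          [0.22, 0.61, 0.17],
                          [0.71, 0.09, 0.20],
                          [0.05, 0.82, 0.13]])
    target = torch.tensor([2, 1, 0, 0])
    confmat(preds, target)                             # tensor([[1, 1, 0], [0, 1, 0], [0, 0, 1]])

The same pattern applies to the multilabel variants (``MultilabelStatScores``, ``MultilabelConfusionMatrix``), which additionally take ``num_labels`` and accept ``(N, C, ...)`` shaped inputs as documented above.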
diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index f66d55ba459..fdd9f25aaba 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -20,7 +20,12 @@ from torchmetrics.functional.classification.average_precision import average_precision from torchmetrics.functional.classification.calibration_error import calibration_error from torchmetrics.functional.classification.cohen_kappa import cohen_kappa -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix +from torchmetrics.functional.classification.confusion_matrix import ( + binary_confusion_matrix, + confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) from torchmetrics.functional.classification.dice import dice, dice_score from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score from torchmetrics.functional.classification.hamming import hamming_distance @@ -37,7 +42,12 @@ ) from torchmetrics.functional.classification.roc import roc from torchmetrics.functional.classification.specificity import specificity -from torchmetrics.functional.classification.stat_scores import stat_scores +from torchmetrics.functional.classification.stat_scores import ( + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, + stat_scores, +) from torchmetrics.functional.image.d_lambda import spectral_distortion_index from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis from torchmetrics.functional.image.gradients import image_gradients @@ -171,4 +181,11 @@ "word_error_rate", "word_information_lost", "word_information_preserved", +] + [ + "binary_confusion_matrix", + "multiclass_confusion_matrix", + "multilabel_confusion_matrix", + "binary_stat_scores", + "multiclass_stat_scores", + "multilabel_stat_scores", ] diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 70f777b56e0..f1efe75fb5b 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -17,7 +17,12 @@ from torchmetrics.functional.classification.average_precision import average_precision # noqa: F401 from torchmetrics.functional.classification.calibration_error import calibration_error # noqa: F401 from torchmetrics.functional.classification.cohen_kappa import cohen_kappa # noqa: F401 -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix # noqa: F401 +from torchmetrics.functional.classification.confusion_matrix import ( # noqa: F401 + binary_confusion_matrix, + confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) from torchmetrics.functional.classification.dice import dice, dice_score # noqa: F401 from torchmetrics.functional.classification.f_beta import f1_score, fbeta_score # noqa: F401 from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401 @@ -34,4 +39,9 @@ ) from torchmetrics.functional.classification.roc import roc # noqa: F401 from torchmetrics.functional.classification.specificity import specificity # noqa: F401 -from torchmetrics.functional.classification.stat_scores import stat_scores # noqa: F401 +from torchmetrics.functional.classification.stat_scores import ( # noqa: F401 + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, + stat_scores, +) diff --git 
a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 26e47a0633b..53df4956889 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -98,7 +98,7 @@ def _average_precision_compute( if preds.ndim == target.ndim and target.ndim > 1: weights = target.sum(dim=0).float() else: - weights = _bincount(target, minlength=num_classes).float() + weights = _bincount(target, minlength=max(num_classes, 2)).float() weights = weights / torch.sum(weights) else: weights = None diff --git a/src/torchmetrics/functional/classification/confusion_matrix.py b/src/torchmetrics/functional/classification/confusion_matrix.py index 362276dc146..79dfaae6a94 100644 --- a/src/torchmetrics/functional/classification/confusion_matrix.py +++ b/src/torchmetrics/functional/classification/confusion_matrix.py @@ -11,15 +11,577 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor +from typing_extensions import Literal -from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.data import _bincount +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.data import _bincount, _movedim from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.prints import rank_zero_warn + + +def _confusion_matrix_reduce( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduce an un-normalized confusion matrix + Args: + confmat: un-normalized confusion matrix + normalize: normalization method. + - `"true"` will divide by the sum of the column dimension. + - `"pred"` will divide by the sum of the row dimension. + - `"all"` will divide by the sum of the full matrix + - `"none"` or `None` will apply no reduction + + Returns: + Normalized confusion matrix + """ + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Argument `normalize` needs to one of the following: {allowed_normalize}") + if normalize is not None and normalize != "none": + confmat = confmat.float() if not confmat.is_floating_point() else confmat + if normalize == "true": + confmat = confmat / confmat.sum(axis=-1, keepdim=True) + elif normalize == "pred": + confmat = confmat / confmat.sum(axis=-2, keepdim=True) + elif normalize == "all": + confmat = confmat / confmat.sum(axis=[-2, -1], keepdim=True) + + nan_elements = confmat[torch.isnan(confmat)].nelement() + if nan_elements: + confmat[torch.isnan(confmat)] = 0 + rank_zero_warn(f"{nan_elements} NaN values found in confusion matrix have been replaced with zeros.") + return confmat + + +def _binary_confusion_matrix_arg_validation( + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. 
+ + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(threshold, float) and not (0 <= threshold <= 1): + raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _binary_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + """ + # Check that they have same shape + _check_same_shape(preds, target) + + # Check that target only contains {0,1} values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains {0,1} values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." + ) + + +def _binary_confusion_matrix_format( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - Remove all datapoints that should be ignored + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + """ + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + preds = preds > threshold + + return preds, target + + +def _binary_confusion_matrix_update(preds: Tensor, target: Tensor) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = (target * 2 + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=4) + return bins.reshape(2, 2) + + +def _binary_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. 
+ """ + return _confusion_matrix_reduce(confmat, normalize) + + +def binary_confusion_matrix( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for binary tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A ``[2, 2]`` tensor + + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0, 1, 0, 0]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_confusion_matrix + >>> target = torch.tensor([1, 1, 0, 0]) + >>> preds = torch.tensor([0.35, 0.85, 0.48, 0.01]) + >>> binary_confusion_matrix(preds, target) + tensor([[2, 0], + [1, 1]]) + """ + if validate_args: + _binary_confusion_matrix_arg_validation(threshold, ignore_index, normalize) + _binary_confusion_matrix_tensor_validation(preds, target, ignore_index) + preds, target = _binary_confusion_matrix_format(preds, target, threshold, ignore_index) + confmat = _binary_confusion_matrix_update(preds, target) + return _binary_confusion_matrix_compute(confmat, normalize) + + +def _multiclass_confusion_matrix_arg_validation( + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. 
+ + - ``num_classes`` has to be a int larger than 1 + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_classes, int) or num_classes < 2: + raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multiclass_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_classes: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - if target has one more dimension than preds, then all dimensions except for preds.shape[1] should match + exactly. preds.shape[1] should have size equal to number of classes + - if preds and target have same number of dims, then all dimensions should match + - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1} + - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1} + """ + if preds.ndim == target.ndim + 1: + if not preds.is_floating_point(): + raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") + if preds.shape[1] != num_classes: + raise ValueError( + "If `preds` have one dimension more than `target`, `preds.shape[1]` should be" + " equal to number of classes." + ) + if preds.shape[2:] != target.shape[1:]: + raise ValueError( + "If `preds` have one dimension more than `target`, the shape of `preds` should be" + " (N, C, ...), and the shape of `target` should be (N, ...)." + ) + elif preds.ndim == target.ndim: + if preds.shape != target.shape: + raise ValueError( + "The `preds` and `target` should have the same shape,", + f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", + ) + else: + raise ValueError( + "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" + " and `preds` should be (N, C, ...)." + ) + + num_unique_values = len(torch.unique(target)) + if ignore_index is None: + check = num_unique_values > num_classes + else: + check = num_unique_values > num_classes + 1 + if check: + raise RuntimeError( + "Detected more unique values in `target` than `num_classes`. Expected only " + f"{num_classes if ignore_index is None else num_classes + 1} but found " + f"{num_unique_values} in `target`." + ) + + if not preds.is_floating_point(): + num_unique_values = len(torch.unique(preds)) + if num_unique_values > num_classes: + raise RuntimeError( + "Detected more unique values in `preds` than `num_classes`. Expected only " + f"{num_classes} but found {num_unique_values} in `preds`." + ) + + +def _multiclass_confusion_matrix_format( + preds: Tensor, target: Tensor, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. 
+ + - Applies argmax if preds have one more dimension than target + - Remove all datapoints that should be ignored + """ + # Apply argmax if we have one more dimension + if preds.ndim == target.ndim + 1: + preds = preds.argmax(dim=1) + + preds = preds.flatten() + target = target.flatten() + + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + + return preds, target + + +def _multiclass_confusion_matrix_update(preds: Tensor, target: Tensor, num_classes: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + return bins.reshape(num_classes, num_classes) + + +def _multiclass_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ + return _confusion_matrix_reduce(confmat, normalize) + + +def multiclass_confusion_matrix( + preds: Tensor, + target: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for multiclass tasks. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A ``[num_classes, num_classes]`` tensor + + Example (pred is integer tensor): + >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + + Example (pred is float tensor): + >>> from torchmetrics.functional import multiclass_confusion_matrix + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... 
]) + >>> multiclass_confusion_matrix(preds, target, num_classes=3) + tensor([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + """ + if validate_args: + _multiclass_confusion_matrix_arg_validation(num_classes, ignore_index, normalize) + _multiclass_confusion_matrix_tensor_validation(preds, target, num_classes, ignore_index) + preds, target = _multiclass_confusion_matrix_format(preds, target, ignore_index) + confmat = _multiclass_confusion_matrix_update(preds, target, num_classes) + return _multiclass_confusion_matrix_compute(confmat, normalize) + + +def _multilabel_confusion_matrix_arg_validation( + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` should be an int larger than 1 + - ``threshold`` has to be a float in the [0,1] range + - ``ignore_index`` has to be None or int + - ``normalize`` has to be "true" | "pred" | "all" | "none" | None + """ + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") + if not isinstance(threshold, float): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + allowed_normalize = ("true", "pred", "all", "none", None) + if normalize not in allowed_normalize: + raise ValueError(f"Expected argument `normalize` to be one of {allowed_normalize}, but got {normalize}.") + + +def _multilabel_confusion_matrix_tensor_validation( + preds: Tensor, target: Tensor, num_labels: int, ignore_index: Optional[int] = None +) -> None: + """Validate tensor input. + + - tensors have to be of same shape + - the second dimension of both tensors need to be equal to the number of labels + - all values in target tensor that are not ignored have to be in {0, 1} + - if pred tensor is not floating point, then all values also have to be in {0, 1} + """ + # Check that they have same shape + _check_same_shape(preds, target) + + if preds.shape[1] != num_labels: + raise ValueError( + "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels" + f" but got {preds.shape[1]} and expected {num_labels}" + ) + + # Check that target only contains [0,1] values or value in ignore_index + unique_values = torch.unique(target) + if ignore_index is None: + check = torch.any((unique_values != 0) & (unique_values != 1)) + else: + check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index)) + if check: + raise RuntimeError( + f"Detected the following values in `target`: {unique_values} but expected only" + f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}." + ) + + # If preds is label tensor, also check that it only contains [0,1] values + if not preds.is_floating_point(): + unique_values = torch.unique(preds) + if torch.any((unique_values != 0) & (unique_values != 1)): + raise RuntimeError( + f"Detected the following values in `preds`: {unique_values} but expected only" + " the following values [0,1] since preds is a label tensor." 
+ ) + + +def _multilabel_confusion_matrix_format( + preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None +) -> Tuple[Tensor, Tensor]: + """Convert all input to label format. + + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all elements that should be ignored with negative numbers for later filtration + """ + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + preds = preds.sigmoid() + preds = preds > threshold + preds = _movedim(preds, 1, -1).reshape(-1, num_labels) + target = _movedim(target, 1, -1).reshape(-1, num_labels) + + if ignore_index is not None: + preds = preds.clone() + target = target.clone() + # Make sure that when we map, it will always result in a negative number that we can filter away + # Each label correspond to a 2x2 matrix = 4 elements per label + idx = target == ignore_index + preds[idx] = -4 * num_labels + target[idx] = -4 * num_labels + + return preds, target + + +def _multilabel_confusion_matrix_update(preds: Tensor, target: Tensor, num_labels: int) -> Tensor: + """Computes the bins to update the confusion matrix with.""" + unique_mapping = ((2 * target + preds) + 4 * torch.arange(num_labels, device=preds.device)).flatten() + unique_mapping = unique_mapping[unique_mapping >= 0] + bins = _bincount(unique_mapping, minlength=4 * num_labels) + return bins.reshape(num_labels, 2, 2) + + +def _multilabel_confusion_matrix_compute( + confmat: Tensor, normalize: Optional[Literal["true", "pred", "all", "none"]] = None +) -> Tensor: + """Reduces the confusion matrix to it's final form. + + Normalization technique can be chosen by ``normalize``. + """ + return _confusion_matrix_reduce(confmat, normalize) + + +def multilabel_confusion_matrix( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + ignore_index: Optional[int] = None, + normalize: Optional[Literal["true", "pred", "all", "none"]] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the `confusion matrix`_ for multilabel tasks. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + Additional dimension ``...`` will be flattened into the batch dimension. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + normalize: Normalization mode for confusion matrix. Choose from: + + - ``None`` or ``'none'``: no normalization (default) + - ``'true'``: normalization over the targets (most commonly used) + - ``'pred'``: normalization over the predictions + - ``'all'``: normalization over the whole matrix + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
+ + Returns: + A ``[num_labels, 2, 2]`` tensor + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_confusion_matrix + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_confusion_matrix(preds, target, num_labels=3) + tensor([[[1, 0], [0, 1]], + [[1, 0], [1, 0]], + [[0, 1], [0, 1]]]) + """ + if validate_args: + _multilabel_confusion_matrix_arg_validation(num_labels, threshold, ignore_index, normalize) + _multilabel_confusion_matrix_tensor_validation(preds, target, num_labels, ignore_index) + preds, target = _multilabel_confusion_matrix_format(preds, target, num_labels, threshold, ignore_index) + confmat = _multilabel_confusion_matrix_update(preds, target, num_labels) + return _multilabel_confusion_matrix_compute(confmat, normalize) + + +# -------------------------- Old stuff -------------------------- def _confusion_matrix_update( @@ -182,5 +744,12 @@ def confusion_matrix( [[0, 1], [0, 1]]]) """ + rank_zero_warn( + "`torchmetrics.functional.confusion_matrix` have been deprecated in v0.10 in favor of" + "`torchmetrics.functional.binary_confusion_matrix`, `torchmetrics.functional.multiclass_confusion_matrix`" + "and `torchmetrics.functional.multilabel_confusion_matrix`. Please upgrade to the version that matches" + "your problem (API may have changed). This function will be removed v0.11.", + DeprecationWarning, + ) confmat = _confusion_matrix_update(preds, target, num_classes, threshold, multilabel) return _confusion_matrix_compute(confmat, normalize) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index b3cb7786e49..176e95a3ff1 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -15,9 +15,802 @@ import torch from torch import Tensor, tensor +from typing_extensions import Literal -from torchmetrics.utilities.checks import _input_format_classification +from torchmetrics.utilities.checks import _check_same_shape, _input_format_classification +from torchmetrics.utilities.data import _bincount, _movedim, select_topk from torchmetrics.utilities.enums import AverageMethod, DataType, MDMCAverageMethod +from torchmetrics.utilities.prints import rank_zero_warn + + +def _binary_stat_scores_arg_validation( + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. 
+
+    - ``threshold`` has to be a float in the [0,1] range
+    - ``multidim_average`` has to be either "global" or "samplewise"
+    - ``ignore_index`` has to be None or int
+    """
+    if not isinstance(threshold, float) or not (0 <= threshold <= 1):
+        raise ValueError(f"Expected argument `threshold` to be a float in the [0,1] range, but got {threshold}.")
+    allowed_multidim_average = ("global", "samplewise")
+    if multidim_average not in allowed_multidim_average:
+        raise ValueError(
+            f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}"
+        )
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+
+
+def _binary_stat_scores_tensor_validation(
+    preds: Tensor,
+    target: Tensor,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    """Validate tensor input.
+
+    - tensors have to be of same shape
+    - all values in target tensor that are not ignored have to be in {0, 1}
+    - if pred tensor is not floating point, then all values also have to be in {0, 1}
+    - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be at least 2-dimensional
+    """
+    # Check that they have same shape
+    _check_same_shape(preds, target)
+
+    # Check that target only contains [0,1] values or value in ignore_index
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = torch.any((unique_values != 0) & (unique_values != 1))
+    else:
+        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
+    if check:
+        raise RuntimeError(
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}."
+        )
+
+    # If preds is label tensor, also check that it only contains [0,1] values
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if torch.any((unique_values != 0) & (unique_values != 1)):
+            raise RuntimeError(
+                f"Detected the following values in `preds`: {unique_values} but expected only"
+                " the following values [0,1] since `preds` is a label tensor."
+            )
+
+    if multidim_average != "global" and preds.ndim < 2:
+        raise ValueError("Expected input to be at least 2D when multidim_average is set to `samplewise`")
+
+
+def _binary_stat_scores_format(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    ignore_index: Optional[int] = None,
+) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format.
+ + - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range + - If preds tensor is floating point, thresholds afterwards + - Mask all datapoints that should be ignored with negative values + """ + if preds.is_floating_point(): + if not torch.all((0 <= preds) * (preds <= 1)): + # preds is logits, convert with sigmoid + preds = preds.sigmoid() + preds = preds > threshold + + preds = preds.reshape(preds.shape[0], -1) + target = target.reshape(target.shape[0], -1) + + if ignore_index is not None: + idx = target == ignore_index + target = target.clone() + target[idx] = -1 + + return preds, target + + +def _binary_stat_scores_update( + preds: Tensor, + target: Tensor, + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Computes the statistics.""" + sum_dim = [0, 1] if multidim_average == "global" else 1 + tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze() + fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze() + fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze() + tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze() + return tp, fp, tn, fn + + +def _binary_stat_scores_compute( + tp: Tensor, fp: Tensor, tn: Tensor, fn: Tensor, multidim_average: Literal["global", "samplewise"] = "global" +) -> Tensor: + """Stack statistics and compute support also.""" + return torch.stack([tp, fp, tn, fn, tp + fn], dim=0 if multidim_average == "global" else 1).squeeze() + + +def binary_stat_scores( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for binary tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). 
The shape + depends on the ``multidim_average`` parameter: + + - If ``multidim_average`` is set to ``global``, the shape will be ``(5,)`` + - If ``multidim_average`` is set to ``samplewise``, the shape will be ``(N, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_stat_scores(preds, target) + tensor([2, 1, 2, 1, 3]) + + Example (multidim tensors): + >>> from torchmetrics.functional import binary_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_stat_scores(preds, target, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _binary_stat_scores_compute(tp, fp, tn, fn, multidim_average) + + +def _multiclass_stat_scores_arg_validation( + num_classes: int, + top_k: int = 1, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. 
+
+    - ``num_classes`` has to be an int larger than 1
+    - ``top_k`` has to be an int larger than 0 but no larger than number of classes
+    - ``average`` has to be "micro" | "macro" | "weighted" | "none"
+    - ``multidim_average`` has to be either "global" or "samplewise"
+    - ``ignore_index`` has to be None or int
+    """
+    if not isinstance(num_classes, int) or num_classes < 2:
+        raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}")
+    if not isinstance(top_k, int) or top_k < 1:
+        raise ValueError(f"Expected argument `top_k` to be an integer larger than or equal to 1, but got {top_k}")
+    if top_k > num_classes:
+        raise ValueError(
+            f"Expected argument `top_k` to be smaller or equal to `num_classes` but got {top_k} and {num_classes}"
+        )
+    allowed_average = ("micro", "macro", "weighted", "none", None)
+    if average not in allowed_average:
+        raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}")
+    allowed_multidim_average = ("global", "samplewise")
+    if multidim_average not in allowed_multidim_average:
+        raise ValueError(
+            f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}"
+        )
+    if ignore_index is not None and not isinstance(ignore_index, int):
+        raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}")
+
+
+def _multiclass_stat_scores_tensor_validation(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> None:
+    """Validate tensor input.
+
+    - if preds has one more dimension than target, then all dimensions except for preds.shape[1] should match
+      exactly. preds.shape[1] should have size equal to number of classes
+    - if preds and target have same number of dims, then all dimensions should match
+    - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be at least 2-dimensional in the
+      int case and 3-dimensional in the float case
+    - all values in target tensor that are not ignored have to be {0, ..., num_classes - 1}
+    - if pred tensor is not floating point, then all values also have to be in {0, ..., num_classes - 1}
+    """
+    if preds.ndim == target.ndim + 1:
+        if not preds.is_floating_point():
+            raise ValueError("If `preds` has one dimension more than `target`, `preds` should be a float tensor.")
+        if preds.shape[1] != num_classes:
+            raise ValueError(
+                "If `preds` has one dimension more than `target`, `preds.shape[1]` should be"
+                " equal to the number of classes."
+            )
+        if preds.shape[2:] != target.shape[1:]:
+            raise ValueError(
+                "If `preds` has one dimension more than `target`, the shape of `preds` should be"
+                " (N, C, ...), and the shape of `target` should be (N, ...)."
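+                # this mirrors the common ``(N, C, ...)`` logits / ``(N, ...)`` targets convention used,
+                # for example, by ``torch.nn.functional.cross_entropy``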
+            )
+        if multidim_average != "global" and preds.ndim < 3:
+            raise ValueError(
+                "If `preds` has one dimension more than `target`, the shape of `preds` should be"
+                " at least 3D when multidim_average is set to `samplewise`"
+            )
+
+    elif preds.ndim == target.ndim:
+        if preds.shape != target.shape:
+            raise ValueError(
+                "The `preds` and `target` should have the same shape,"
+                f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}."
+            )
+        if multidim_average != "global" and preds.ndim < 2:
+            raise ValueError(
+                "When `preds` and `target` have the same shape, the shape of `preds` should be"
+                " at least 2D when multidim_average is set to `samplewise`"
+            )
+    else:
+        raise ValueError(
+            "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)"
+            " and `preds` should be (N, C, ...)."
+        )
+
+    num_unique_values = len(torch.unique(target))
+    if ignore_index is None:
+        check = num_unique_values > num_classes
+    else:
+        check = num_unique_values > num_classes + 1
+    if check:
+        raise RuntimeError(
+            "Detected more unique values in `target` than `num_classes`. Expected only "
+            f"{num_classes if ignore_index is None else num_classes + 1} but found "
+            f"{num_unique_values} in `target`."
+        )
+
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if len(unique_values) > num_classes:
+            raise RuntimeError(
+                "Detected more unique values in `preds` than `num_classes`. Expected only "
+                f"{num_classes} but found {len(unique_values)} in `preds`."
+            )
+
+
+def _multiclass_stat_scores_format(
+    preds: Tensor,
+    target: Tensor,
+    top_k: int = 1,
+) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format except if ``top_k`` is not 1.
+
+    - Applies argmax if preds have one more dimension than target
+    - Flattens additional dimensions
+    """
+    # Apply argmax if we have one more dimension
+    if preds.ndim == target.ndim + 1 and top_k == 1:
+        preds = preds.argmax(dim=1)
+    if top_k != 1:
+        preds = preds.reshape(*preds.shape[:2], -1)
+    else:
+        preds = preds.reshape(preds.shape[0], -1)
+    target = target.reshape(target.shape[0], -1)
+    return preds, target
+
+
+def _multiclass_stat_scores_update(
+    preds: Tensor,
+    target: Tensor,
+    num_classes: int,
+    top_k: int = 1,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Computes the statistics.
+
+    - If ``multidim_average`` is equal to samplewise or ``top_k`` is not 1, we transform both preds and
+      target into one-hot format.
+    - Else we calculate statistics by first calculating the confusion matrix and afterwards deriving the
+      statistics from that.
+    - Remove all datapoints that should be ignored. Depending on whether ``ignore_index`` is inside the set of
+      labels or outside it, we have to use different augmentation strategies when one-hot encoding.
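+
+    For example, with ``num_classes=3`` and ``ignore_index=-1`` (a value outside the label range), ignored
+    positions are first remapped to the extra class index ``3``, one-hot encoded with an additional column,
+    and that column is dropped again before the statistics are summed.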
+ """ + if multidim_average == "samplewise" or top_k != 1: + ignore_in = 0 <= ignore_index <= num_classes - 1 if ignore_index is not None else None + if ignore_index is not None and not ignore_in: + preds = preds.clone() + target = target.clone() + idx = target == ignore_index + preds[idx] = num_classes + target[idx] = num_classes + + if top_k > 1: + preds_oh = _movedim(select_topk(preds, topk=top_k, dim=1), 1, -1) + else: + preds_oh = torch.nn.functional.one_hot( + preds, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) + target_oh = torch.nn.functional.one_hot( + target, num_classes + 1 if ignore_index is not None and not ignore_in else num_classes + ) + if ignore_index is not None: + if 0 <= ignore_index <= num_classes - 1: + target_oh[target == ignore_index, :] = -1 + else: + preds_oh = preds_oh[..., :-1] + target_oh = target_oh[..., :-1] + target_oh[target == num_classes, :] = -1 + sum_dim = [0, 1] if multidim_average == "global" else [1] + tp = ((target_oh == preds_oh) & (target_oh == 1)).sum(sum_dim) + fn = ((target_oh != preds_oh) & (target_oh == 1)).sum(sum_dim) + fp = ((target_oh != preds_oh) & (target_oh == 0)).sum(sum_dim) + tn = ((target_oh == preds_oh) & (target_oh == 0)).sum(sum_dim) + return tp, fp, tn, fn + else: + preds = preds.flatten() + target = target.flatten() + if ignore_index is not None: + idx = target != ignore_index + preds = preds[idx] + target = target[idx] + unique_mapping = (target * num_classes + preds).to(torch.long) + bins = _bincount(unique_mapping, minlength=num_classes**2) + confmat = bins.reshape(num_classes, num_classes) + tp = confmat.diag() + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + return tp, fp, tn, fn + + +def _multiclass_stat_scores_compute( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + """Stack statistics and compute support also. + + Applies average strategy afterwards. + """ + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) + sum_dim = 0 if multidim_average == "global" else 1 + if average == "micro": + return res.sum(sum_dim) + elif average == "macro": + return res.float().mean(sum_dim) + elif average == "weighted": + weight = tp + fn + if multidim_average == "global": + return (res * (weight / weight.sum()).reshape(*weight.shape, 1)).sum(sum_dim) + else: + return (res * (weight / weight.sum(-1, keepdim=True)).reshape(*weight.shape, 1)).sum(sum_dim) + elif average is None or average == "none": + return res + + +def multiclass_stat_scores( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multiclass tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. 
+ - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... 
]) + >>> multiclass_stat_scores(preds, target, num_classes=3) + tensor([3, 1, 7, 1, 4]) + >>> multiclass_stat_scores(preds, target, num_classes=3, average=None) + tensor([[1, 0, 2, 1, 2], + [1, 1, 2, 0, 1], + [1, 0, 3, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multiclass_stat_scores + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise') + tensor([[3, 3, 9, 3, 6], + [2, 4, 8, 4, 6]]) + >>> multiclass_stat_scores(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[[2, 1, 3, 0, 2], + [0, 1, 3, 2, 2], + [1, 1, 3, 1, 2]], + [[0, 1, 4, 1, 1], + [1, 1, 2, 2, 3], + [1, 2, 2, 1, 2]]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _multiclass_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) + + +def _multilabel_stat_scores_arg_validation( + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, +) -> None: + """Validate non tensor input. + + - ``num_labels`` should be an int larger than 1 + - ``threshold`` has to be a float in the [0,1] range + - ``average`` has to be "micro" | "macro" | "weighted" | "none" + - ``multidim_average`` has to be either "global" or "samplewise" + - ``ignore_index`` has to be None or int + """ + if not isinstance(num_labels, int) or num_labels < 2: + raise ValueError(f"Expected argument `num_labels` to be an integer larger than 1, but got {num_labels}") + if not isinstance(threshold, float): + raise ValueError(f"Expected argument `threshold` to be a float, but got {threshold}.") + allowed_average = ("micro", "macro", "weighted", "none", None) + if average not in allowed_average: + raise ValueError(f"Expected argument `average` to be one of {allowed_average}, but got {average}") + allowed_multidim_average = ("global", "samplewise") + if multidim_average not in allowed_multidim_average: + raise ValueError( + f"Expected argument `multidim_average` to be one of {allowed_multidim_average}, but got {multidim_average}" + ) + if ignore_index is not None and not isinstance(ignore_index, int): + raise ValueError(f"Expected argument `ignore_index` to either be `None` or an integer, but got {ignore_index}") + + +def _multilabel_stat_scores_tensor_validation( + preds: Tensor, + target: Tensor, + num_labels: int, + multidim_average: str, + ignore_index: Optional[int] = None, +) -> None: + """Validate tensor input. 
+
+    - tensors have to be of same shape
+    - the second dimension of both tensors needs to be equal to the number of labels
+    - all values in target tensor that are not ignored have to be in {0, 1}
+    - if pred tensor is not floating point, then all values also have to be in {0, 1}
+    - if ``multidim_average`` is set to ``samplewise`` preds tensor needs to be at least 3-dimensional
+    """
+    # Check that they have same shape
+    _check_same_shape(preds, target)
+
+    if preds.shape[1] != num_labels:
+        raise ValueError(
+            "Expected both `target.shape[1]` and `preds.shape[1]` to be equal to the number of labels"
+            f" but got {preds.shape[1]} and expected {num_labels}"
+        )
+
+    # Check that target only contains [0,1] values or value in ignore_index
+    unique_values = torch.unique(target)
+    if ignore_index is None:
+        check = torch.any((unique_values != 0) & (unique_values != 1))
+    else:
+        check = torch.any((unique_values != 0) & (unique_values != 1) & (unique_values != ignore_index))
+    if check:
+        raise RuntimeError(
+            f"Detected the following values in `target`: {unique_values} but expected only"
+            f" the following values {[0,1] + [] if ignore_index is None else [ignore_index]}."
+        )
+
+    # If preds is label tensor, also check that it only contains [0,1] values
+    if not preds.is_floating_point():
+        unique_values = torch.unique(preds)
+        if torch.any((unique_values != 0) & (unique_values != 1)):
+            raise RuntimeError(
+                f"Detected the following values in `preds`: {unique_values} but expected only"
+                " the following values [0,1] since preds is a label tensor."
+            )
+
+    if multidim_average != "global" and preds.ndim < 3:
+        raise ValueError("Expected input to be at least 3D when multidim_average is set to `samplewise`")
+
+
+def _multilabel_stat_scores_format(
+    preds: Tensor, target: Tensor, num_labels: int, threshold: float = 0.5, ignore_index: Optional[int] = None
+) -> Tuple[Tensor, Tensor]:
+    """Convert all input to label format.
+
+    - If preds tensor is floating point, applies sigmoid if pred tensor not in [0,1] range
+    - If preds tensor is floating point, thresholds afterwards
+    - Mask all elements that should be ignored with negative numbers for later filtration
+    """
+    if preds.is_floating_point():
+        if not torch.all((0 <= preds) * (preds <= 1)):
+            preds = preds.sigmoid()
+        preds = preds > threshold
+    preds = preds.reshape(*preds.shape[:2], -1)
+    target = target.reshape(*target.shape[:2], -1)
+
+    if ignore_index is not None:
+        idx = target == ignore_index
+        target = target.clone()
+        target[idx] = -1
+
+    return preds, target
+
+
+def _multilabel_stat_scores_update(
+    preds: Tensor, target: Tensor, multidim_average: Literal["global", "samplewise"] = "global"
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    """Computes the statistics."""
+    sum_dim = [0, -1] if multidim_average == "global" else [-1]
+    tp = ((target == preds) & (target == 1)).sum(sum_dim).squeeze()
+    fn = ((target != preds) & (target == 1)).sum(sum_dim).squeeze()
+    fp = ((target != preds) & (target == 0)).sum(sum_dim).squeeze()
+    tn = ((target == preds) & (target == 0)).sum(sum_dim).squeeze()
+    return tp, fp, tn, fn
+
+
+def _multilabel_stat_scores_compute(
+    tp: Tensor,
+    fp: Tensor,
+    tn: Tensor,
+    fn: Tensor,
+    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
+    multidim_average: Literal["global", "samplewise"] = "global",
+) -> Tensor:
+    """Stack the statistics and also compute the support.
+
+    Applies the chosen ``average`` strategy afterwards.
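+
+    For ``average="weighted"`` the per-label statistics are weighted by the label support ``tp + fn``
+    before being summed, so labels with more positive examples contribute proportionally more.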
+ """ + res = torch.stack([tp, fp, tn, fn, tp + fn], dim=-1) + sum_dim = 0 if multidim_average == "global" else 1 + if average == "micro": + return res.sum(sum_dim) + elif average == "macro": + return res.float().mean(sum_dim) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(*w.shape, 1)).sum(sum_dim) + elif average is None or average == "none": + return res + + +def multilabel_stat_scores( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r""" + Computes the number of true positives, false positives, true negatives, false negatives and the support + for multilabel tasks. Related to `Type I and Type II errors`_. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The metric returns a tensor of shape ``(..., 5)``, where the last dimension corresponds + to ``[tp, fp, tn, fn, sup]`` (``sup`` stands for support and equals ``tp + fn``). 
The shape + depends on ``average`` and ``multidim_average`` parameters: + + - If ``multidim_average`` is set to ``global`` + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(5,)`` + - If ``average=None/'none'``, the shape will be ``(C, 5)`` + + - If ``multidim_average`` is set to ``samplewise`` + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N, 5)`` + - If ``average=None/'none'``, the shape will be ``(N, C, 5)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_stat_scores(preds, target, num_labels=3) + tensor([2, 1, 2, 1, 3]) + >>> multilabel_stat_scores(preds, target, num_labels=3, average=None) + tensor([[1, 0, 1, 0, 1], + [0, 0, 1, 1, 1], + [1, 1, 0, 0, 1]]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_stat_scores + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise') + tensor([[2, 3, 0, 1, 3], + [0, 2, 1, 3, 3]]) + >>> multilabel_stat_scores(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[[1, 1, 0, 0, 1], + [1, 1, 0, 0, 1], + [0, 1, 0, 1, 1]], + [[0, 0, 0, 2, 2], + [0, 2, 0, 0, 0], + [0, 0, 1, 1, 1]]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _multilabel_stat_scores_compute(tp, fp, tn, fn, average, multidim_average) + + +# -------------------------- Old stuff -------------------------- def _del_column(data: Tensor, idx: int) -> Tensor: @@ -416,6 +1209,13 @@ def stat_scores( tensor([2, 2, 6, 2, 4]) """ + rank_zero_warn( + "`torchmetrics.functional.stat_scores` have been deprecated in v0.10 in favor of" + "`torchmetrics.functional.binary_stat_scores`, `torchmetrics.functional.multiclass_stat_scores`" + "and `torchmetrics.functional.multilabel_stat_scores`. Please upgrade to the version that matches" + "your problem (API may have changed). 
This function will be removed v0.11.", + DeprecationWarning, + ) if reduce not in ["micro", "macro", "samples"]: raise ValueError(f"The `reduce` {reduce} is not valid.") diff --git a/src/torchmetrics/utilities/checks.py b/src/torchmetrics/utilities/checks.py index 54d15b3d455..dfcb2922147 100644 --- a/src/torchmetrics/utilities/checks.py +++ b/src/torchmetrics/utilities/checks.py @@ -32,7 +32,9 @@ def _check_for_empty_tensors(preds: Tensor, target: Tensor) -> bool: def _check_same_shape(preds: Tensor, target: Tensor) -> None: """Check that predictions and target have the same shape, else raise error.""" if preds.shape != target.shape: - raise RuntimeError("Predictions and targets are expected to have the same shape") + raise RuntimeError( + f"Predictions and targets are expected to have the same shape, but got {preds.shape} and {target.shape}." + ) def _basic_input_validation( diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index f9cd329fb11..92df978e96c 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -261,15 +261,15 @@ def _bincount(x: Tensor, minlength: Optional[int] = None) -> Tensor: Returns: Number of occurrences for each unique element in x """ - if x.is_cuda and deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: - if minlength is None: - minlength = len(torch.unique(x)) + if minlength is None: + minlength = len(torch.unique(x)) + if deterministic() or _TORCH_GREATER_EQUAL_1_12 and x.is_mps: output = torch.zeros(minlength, device=x.device, dtype=torch.long) for i in range(minlength): output[i] = (x == i).sum() return output - else: - return torch.bincount(x, minlength=minlength) + z = torch.zeros(minlength, device=x.device, dtype=x.dtype) + return z.index_add_(0, x, torch.ones_like(x)) def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: @@ -277,3 +277,11 @@ def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: if tensor1.dtype != tensor2.dtype: tensor2 = tensor2.to(dtype=tensor1.dtype) return torch.allclose(tensor1, tensor2) + + +def _movedim(tensor: Tensor, dim1: int, dim2: int) -> tensor: + if _TORCH_GREATER_EQUAL_1_7: + return torch.movedim(tensor, dim1, dim2) + if dim2 >= 0: + dim2 += 1 + return tensor.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1) diff --git a/tests/unittests/classification/inputs.py b/tests/unittests/classification/inputs.py index ff88b452638..bb3762a0edd 100644 --- a/tests/unittests/classification/inputs.py +++ b/tests/unittests/classification/inputs.py @@ -14,12 +14,22 @@ from collections import namedtuple import torch +from torch import Tensor from unittests.helpers import seed_all from unittests.helpers.testers import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, NUM_CLASSES seed_all(1) + +def _inv_sigmoid(x: Tensor) -> Tensor: + return (x / (1 - x)).log() + + +def _logsoftmax(x: Tensor, dim: int = -1) -> Tensor: + return torch.nn.functional.log_softmax(x, dim) + + Input = namedtuple("Input", ["preds", "target"]) _input_binary_prob = Input( @@ -60,6 +70,86 @@ target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), ) +_binary_cases = ( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE))), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + 
preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), +) + + +_multiclass_cases = ( + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES).softmax(-1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), -1), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + ), + Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM).softmax(-2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), + Input( + preds=_logsoftmax(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), -2), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)), + ), +) + + +_multilabel_cases = ( + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + ), + Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), + Input( + preds=_inv_sigmoid(torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM)), + ), +) + # Generate edge multilabel edge case, where nothing matches (scores are undefined) __temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) __temp_target = abs(__temp_preds - 1) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index f54ca95c7d1..b52ba3a216f 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -12,194 +12,478 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Any, Dict import numpy as np import pytest import torch +from scipy.special import expit as sigmoid from sklearn.metrics import confusion_matrix as sk_confusion_matrix -from sklearn.metrics import multilabel_confusion_matrix as sk_multilabel_confusion_matrix - -from torchmetrics import JaccardIndex -from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.confusion_matrix import confusion_matrix -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob + +from torchmetrics.classification.confusion_matrix import ( + BinaryConfusionMatrix, + MulticlassConfusionMatrix, + MultilabelConfusionMatrix, +) +from torchmetrics.functional.classification.confusion_matrix import ( + binary_confusion_matrix, + multiclass_confusion_matrix, + multilabel_confusion_matrix, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) -def _sk_cm_binary_prob(preds, target, normalize=None): - sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_binary(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() - - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) - - -def _sk_cm_multilabel_prob(preds, target, normalize=None): - sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) - sk_target = target.numpy() - - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm - - -def _sk_cm_multilabel(preds, target, normalize=None): - sk_preds = preds.numpy() - sk_target = target.numpy() +def _sk_confusion_matrix_binary(preds, target, normalize=None, ignore_index=None): + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_confusion_matrix(y_true=target, y_pred=preds, 
labels=[0, 1], normalize=normalize) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_binary_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) - cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) - if normalize is not None: - if normalize == "true": - cm = cm / cm.sum(axis=1, keepdims=True) - elif normalize == "pred": - cm = cm / cm.sum(axis=0, keepdims=True) - elif normalize == "all": - cm = cm / cm.sum() - cm[np.isnan(cm)] = 0 - return cm + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_binary_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_binary, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "threshold": THRESHOLD, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + def test_binary_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + ) -def _sk_cm_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_confusion_matrix_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryConfusionMatrix, + metric_functional=binary_confusion_matrix, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -def _sk_cm_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() +def 
_sk_confusion_matrix_multiclass(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + preds = np.argmax(preds, axis=1) + preds = preds.flatten() + target = target.flatten() - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) -def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): - sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() - sk_target = target.view(-1).numpy() +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multiclass_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multiclass, normalize=normalize, ignore_index=ignore_index), + metric_args={ + "num_classes": NUM_CLASSES, + "normalize": normalize, + "ignore_index": ignore_index, + }, + ) + def test_multiclass_confusion_matrix_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + ) -def _sk_cm_multidim_multiclass(preds, target, normalize=None): - sk_preds = preds.view(-1).numpy() - sk_target = target.view(-1).numpy() + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_confusion_matrix_half_gpu(self, input, 
dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassConfusionMatrix, + metric_functional=multiclass_confusion_matrix, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) -@pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) -@pytest.mark.parametrize( - "preds, target, sk_metric, num_classes, multilabel", - [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), - (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), - (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), - (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), - (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), - (_input_mdmc.preds, _input_mdmc.target, _sk_cm_multidim_multiclass, NUM_CLASSES, False), - ], -) -class TestConfusionMatrix(MetricTester): +def _sk_confusion_matrix_multilabel(preds, target, normalize=None, ignore_index=None): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) + target = np.moveaxis(target, 1, -1).reshape((-1, target.shape[1])) + confmat = [] + for i in range(preds.shape[1]): + p, t = preds[:, i], target[:, i] + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + confmat.append(sk_confusion_matrix(t, p, normalize=normalize, labels=[0, 1])) + return np.stack(confmat, axis=0) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelConfusionMatrix(MetricTester): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + @pytest.mark.parametrize("ignore_index", [None, -1, 0]) @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_confusion_matrix( - self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step - ): + def test_multilabel_confusion_matrix(self, input, ddp, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=ConfusionMatrix, - sk_metric=partial(sk_metric, normalize=normalize), - dist_sync_on_step=dist_sync_on_step, + metric_class=MultilabelConfusionMatrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): + @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) + 
@pytest.mark.parametrize("ignore_index", [None, -1, 0]) + def test_multilabel_confusion_matrix_functional(self, input, normalize, ignore_index): + preds, target = input + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) self.run_functional_metric_test( preds=preds, target=target, - metric_functional=confusion_matrix, - sk_metric=partial(sk_metric, normalize=normalize), + metric_functional=multilabel_confusion_matrix, + sk_metric=partial(_sk_confusion_matrix_multilabel, normalize=normalize, ignore_index=ignore_index), metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, + "num_labels": NUM_CLASSES, "normalize": normalize, - "multilabel": multilabel, + "ignore_index": ignore_index, }, ) - def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): + def test_multilabel_confusion_matrix_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=ConfusionMatrix, - metric_functional=confusion_matrix, - metric_args={ - "num_classes": num_classes, - "threshold": THRESHOLD, - "normalize": normalize, - "multilabel": multilabel, - }, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_confusion_matrix_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelConfusionMatrix, + metric_functional=multilabel_confusion_matrix, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, ) -def test_warning_on_nan(tmpdir): +def test_warning_on_nan(): preds = torch.randint(3, size=(20,)) target = torch.randint(3, size=(20,)) with pytest.warns( UserWarning, - match=".* nan values found in confusion matrix have been replaced with zeros.", + match=".* NaN values found in confusion matrix have been replaced with zeros.", ): - confusion_matrix(preds, target, num_classes=5, normalize="true") - - -@pytest.mark.parametrize( - "metric_args", - [ - {"num_classes": 1, "normalize": "true"}, - {"num_classes": 1, "normalize": "pred"}, - {"num_classes": 1, "normalize": "all"}, - {"num_classes": 1, "normalize": "none"}, - {"num_classes": 1, "normalize": None}, - ], -) -def test_provide_superclass_kwargs(metric_args: Dict[str, Any]): - """Test instantiating subclasses with superclass arguments as kwargs.""" - JaccardIndex(**metric_args) + multiclass_confusion_matrix(preds, target, num_classes=5, normalize="true") + + +# -------------------------- Old stuff -------------------------- + +# def 
_sk_cm_binary_prob(preds, target, normalize=None): +# sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_binary(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multilabel_prob(preds, target, normalize=None): +# sk_preds = (preds.numpy() >= THRESHOLD).astype(np.uint8) +# sk_target = target.numpy() + +# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) +# if normalize is not None: +# if normalize == "true": +# cm = cm / cm.sum(axis=1, keepdims=True) +# elif normalize == "pred": +# cm = cm / cm.sum(axis=0, keepdims=True) +# elif normalize == "all": +# cm = cm / cm.sum() +# cm[np.isnan(cm)] = 0 +# return cm + + +# def _sk_cm_multilabel(preds, target, normalize=None): +# sk_preds = preds.numpy() +# sk_target = target.numpy() + +# cm = sk_multilabel_confusion_matrix(y_true=sk_target, y_pred=sk_preds) +# if normalize is not None: +# if normalize == "true": +# cm = cm / cm.sum(axis=1, keepdims=True) +# elif normalize == "pred": +# cm = cm / cm.sum(axis=0, keepdims=True) +# elif normalize == "all": +# cm = cm / cm.sum() +# cm[np.isnan(cm)] = 0 +# return cm + + +# def _sk_cm_multiclass_prob(preds, target, normalize=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multiclass(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multidim_multiclass_prob(preds, target, normalize=None): +# sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# def _sk_cm_multidim_multiclass(preds, target, normalize=None): +# sk_preds = preds.view(-1).numpy() +# sk_target = target.view(-1).numpy() + +# return sk_confusion_matrix(y_true=sk_target, y_pred=sk_preds, normalize=normalize) + + +# @pytest.mark.parametrize("normalize", ["true", "pred", "all", None]) +# @pytest.mark.parametrize( +# "preds, target, sk_metric, num_classes, multilabel", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_cm_binary_prob, 2, False), +# (_input_binary_logits.preds, _input_binary_logits.target, _sk_cm_binary_prob, 2, False), +# (_input_binary.preds, _input_binary.target, _sk_cm_binary, 2, False), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), +# (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_cm_multilabel_prob, NUM_CLASSES, True), +# (_input_mlb.preds, _input_mlb.target, _sk_cm_multilabel, NUM_CLASSES, True), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), +# (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_cm_multiclass_prob, NUM_CLASSES, False), +# (_input_mcls.preds, _input_mcls.target, _sk_cm_multiclass, NUM_CLASSES, False), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_cm_multidim_multiclass_prob, NUM_CLASSES, False), +# (_input_mdmc.preds, _input_mdmc.target, 
_sk_cm_multidim_multiclass, NUM_CLASSES, False), +# ], +# ) +# class TestConfusionMatrix(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_confusion_matrix( +# self, normalize, preds, target, sk_metric, num_classes, multilabel, ddp, dist_sync_on_step +# ): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=ConfusionMatrix, +# sk_metric=partial(sk_metric, normalize=normalize), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + +# def test_confusion_matrix_functional(self, normalize, preds, target, sk_metric, num_classes, multilabel): +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=confusion_matrix, +# sk_metric=partial(sk_metric, normalize=normalize), +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + +# def test_confusion_matrix_differentiability(self, normalize, preds, target, sk_metric, num_classes, multilabel): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=ConfusionMatrix, +# metric_functional=confusion_matrix, +# metric_args={ +# "num_classes": num_classes, +# "threshold": THRESHOLD, +# "normalize": normalize, +# "multilabel": multilabel, +# }, +# ) + + +# def test_warning_on_nan(tmpdir): +# preds = torch.randint(3, size=(20,)) +# target = torch.randint(3, size=(20,)) + +# with pytest.warns( +# UserWarning, +# match=".* nan values found in confusion matrix have been replaced with zeros.", +# ): +# confusion_matrix(preds, target, num_classes=5, normalize="true") diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 75e52f66b4a..d61284834a7 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -12,334 +12,786 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Any, Callable, Dict, Optional import numpy as np import pytest import torch -from sklearn.metrics import multilabel_confusion_matrix -from torch import Tensor, tensor - -from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores -from torchmetrics.functional import stat_scores -from torchmetrics.utilities.checks import _input_format_classification -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob, _input_multiclass -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mcls -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix + +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.stat_scores import ( + binary_stat_scores, + multiclass_stat_scores, + multilabel_stat_scores, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) -def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None): - # todo: `mdmc_reduce` is unused - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() +def _sk_stat_scores_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return np.array([tp, fp, tn, fn, tp + fn]) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(np.array([tp, fp, tn, fn, tp + fn])) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", 
"samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_stat_scores(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 1) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryStatScores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_stat_scores_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_stat_scores, + sk_metric=partial(_sk_stat_scores_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, + ) - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] - else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + def test_binary_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + ) - if reduce == "micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryStatScores, + 
metric_functional=binary_stat_scores, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - if reduce == "micro": - sk_stats = sk_stats[0] - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 +def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + preds = preds.numpy().flatten() + target = target.numpy().flatten() + + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + res = np.stack([tp, fp, tn, fn, tp + fn], 1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res - return sk_stats + else: + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + r = np.stack([tp, fp, tn, fn, tp + fn], 1) + if average == "micro": + res.append(r.sum(0)) + elif average == "macro": + res.append(r.mean(0)) + elif average == "weighted": + w = tp + fn + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) + elif average is None or average == "none": + res.append(r) + return np.stack(res, 0) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassStatScores(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassStatScores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) -def _sk_stat_scores_mdim_mcls( - preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold -): - preds, target, _ = _input_format_classification( - preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + 
@pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multiclass_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_stat_scores, + sk_metric=partial( + _sk_stat_scores_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) - return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold) - if mdmc_reduce == "samplewise": - scores = [] + def test_multiclass_stat_scores_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + ) - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassStatScores, + metric_functional=multiclass_stat_scores, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) - scores.append(np.expand_dims(scores_i, 0)) - return np.concatenate(scores) +_mc_k_target = torch.tensor([0, 1, 2]) +_mc_k_preds = torch.tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) @pytest.mark.parametrize( - "reduce, mdmc_reduce, num_classes, inputs, ignore_index", + "k, preds, target, average, expected", [ - ["unknown", None, None, _input_binary, None], - ["micro", "unknown", None, _input_binary, None], - ["macro", None, None, _input_binary, None], - ["micro", None, None, _input_mdmc_prob, None], - ["micro", None, None, _input_binary_prob, 0], - ["micro", None, None, _input_mcls_prob, NUM_CLASSES], - ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES], + (1, _mc_k_preds, _mc_k_target, "micro", torch.tensor([2, 1, 5, 1, 3])), + (2, _mc_k_preds, _mc_k_target, "micro", torch.tensor([3, 3, 3, 0, 3])), + (1, _mc_k_preds, _mc_k_target, None, torch.tensor([[0, 1, 
1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), + (2, _mc_k_preds, _mc_k_target, None, torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), ], ) -def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): - """Test a combination of parameters that are invalid and should raise an error. - - This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when - ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index`` - when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes. - """ - with pytest.raises(ValueError): - stat_scores( - inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index - ) +def test_top_k_multiclass(k, preds, target, average, expected): + """A simple test to check that top_k works as expected.""" + class_metric = MulticlassStatScores(top_k=k, average=average, num_classes=3) + class_metric.update(preds, target) - with pytest.raises(ValueError): - sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) - sts(inputs.preds[0], inputs.target[0]) + assert torch.allclose(class_metric.compute().long(), expected.T) + assert torch.allclose( + multiclass_stat_scores(preds, target, top_k=k, average=average, num_classes=3).long(), expected.T + ) -@pytest.mark.parametrize("ignore_index", [None, 0]) -@pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) -@pytest.mark.parametrize( - "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold", - [ - (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0), - (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5), - (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5), - (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5), - (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), - (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0), - (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "samplewise", NUM_CLASSES, None, None, 0.0), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "samplewise", - NUM_CLASSES, - None, - None, - 0.0, - ), - (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0), - ( - _input_mdmc_prob.preds, - _input_mdmc_prob.target, - _sk_stat_scores_mdim_mcls, - "global", - NUM_CLASSES, - None, - None, - 0.0, - ), - ], -) -class TestStatScores(MetricTester): - # DDP tests temporarily disabled due to hanging issues - @pytest.mark.parametrize("ddp", [False]) - @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - 
@pytest.mark.parametrize("dtype", [torch.float, torch.double]) - def test_stat_scores_class( - self, - ddp: bool, - dist_sync_on_step: bool, - dtype: torch.dtype, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - - if preds.is_floating_point(): - preds = preds.to(dtype) - if target.is_floating_point(): - target = target.to(dtype) +def _sk_stat_scores_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + if multidim_average == "global": + stat_scores = [] + for i in range(preds.shape[1]): + p, t = preds[:, i].flatten(), target[:, i].flatten() + if ignore_index is not None: + idx = t == ignore_index + t = t[~idx] + p = p[~idx] + tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel() + stat_scores.append(np.array([tp, fp, tn, fn, tp + fn])) + res = np.stack(stat_scores, axis=0) + + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = res[:, 0] + res[:, 3] + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + else: + stat_scores = [] + for i in range(preds.shape[0]): + scores = [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() + scores.append(np.array([tp, fp, tn, fn, tp + fn])) + stat_scores.append(np.stack(scores, 1)) + res = np.stack(stat_scores, 0) + if average == "micro": + return res.sum(-1) + elif average == "macro": + return res.mean(-1) + elif average == "weighted": + w = res[:, 0, :] + res[:, 3, :] + return (res * (w / w.sum())[:, np.newaxis]).sum(-1) + elif average is None or average == "none": + return np.moveaxis(res, 1, -1) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelStatScores(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_stat_scores(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=StatScores, + metric_class=MultilabelStatScores, sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, + _sk_stat_scores_multilabel, ignore_index=ignore_index, - top_k=top_k, 
- threshold=threshold, + multidim_average=multidim_average, + average=average, ), - dist_sync_on_step=dist_sync_on_step, metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "top_k": top_k, + "multidim_average": multidim_average, + "average": average, }, ) - def test_stat_scores_fn( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_stat_scores_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") self.run_functional_metric_test( - preds, - target, - metric_functional=stat_scores, + preds=preds, + target=target, + metric_functional=multilabel_stat_scores, sk_metric=partial( - sk_fn, - reduce=reduce, - mdmc_reduce=mdmc_reduce, - num_classes=num_classes, - multiclass=multiclass, + _sk_stat_scores_multilabel, ignore_index=ignore_index, - top_k=top_k, - threshold=threshold, + multidim_average=multidim_average, + average=average, ), metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, "ignore_index": ignore_index, - "top_k": top_k, + "multidim_average": multidim_average, + "average": average, }, ) - def test_stat_scores_differentiability( - self, - sk_fn: Callable, - preds: Tensor, - target: Tensor, - reduce: str, - mdmc_reduce: Optional[str], - num_classes: Optional[int], - multiclass: Optional[bool], - ignore_index: Optional[int], - top_k: Optional[int], - threshold: Optional[float], - ): - if ignore_index is not None and preds.ndim == 2: - pytest.skip("Skipping ignore_index test with binary inputs.") - + def test_multilabel_stat_scores_differentiability(self, input): + preds, target = input self.run_differentiability_test( - preds, - target, - metric_module=StatScores, - metric_functional=stat_scores, - metric_args={ - "num_classes": num_classes, - "reduce": reduce, - "mdmc_reduce": mdmc_reduce, - "threshold": threshold, - "multiclass": multiclass, - "ignore_index": ignore_index, - "top_k": top_k, - }, + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_half_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + 
self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -_mc_k_target = tensor([0, 1, 2]) -_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) -_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) -_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) - - -@pytest.mark.parametrize( - "k, preds, target, reduce, expected", - [ - (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), - (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), - (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), - (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), - (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), - (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), - (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), - (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), - ], -) -def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): - """A simple test to check that top_k works as expected.""" - - class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) - class_metric.update(preds, target) - - assert torch.equal(class_metric.compute(), expected.T) - assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_stat_scores_half_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelStatScores, + metric_functional=multilabel_stat_scores, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize( - "metric_args", - [ - {"reduce": "micro"}, - {"num_classes": 1, "reduce": "macro"}, - {"reduce": "samples"}, - {"mdmc_reduce": None}, - {"mdmc_reduce": "samplewise"}, - {"mdmc_reduce": "global"}, - ], -) -@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity]) -def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]): - """Test instantiating subclasses with superclass arguments as kwargs.""" - metric_cls(**metric_args) +# -------------------------- Old stuff -------------------------- + +# def _sk_stat_scores(preds, target, reduce, num_classes, multiclass, ignore_index, top_k, threshold, mdmc_reduce=None): +# # todo: `mdmc_reduce` is unused +# preds, target, _ = _input_format_classification( +# preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) +# sk_preds, sk_target = preds.numpy(), target.numpy() + +# if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: +# sk_preds = np.delete(sk_preds, ignore_index, 1) +# sk_target = np.delete(sk_target, ignore_index, 1) + +# if preds.shape[1] == 1 and reduce == "samples": +# sk_target = sk_target.T +# sk_preds = sk_preds.T + +# sk_stats = multilabel_confusion_matrix( +# sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 +# ) + +# if preds.shape[1] == 1 and 
reduce != "samples": +# sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +# else: +# sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + +# if reduce == "micro": +# sk_stats = sk_stats.sum(axis=0, keepdims=True) + +# sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + +# if reduce == "micro": +# sk_stats = sk_stats[0] + +# if reduce == "macro" and ignore_index is not None and preds.shape[1]: +# sk_stats[ignore_index, :] = -1 + +# return sk_stats + + +# def _sk_stat_scores_mdim_mcls( +# preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k, threshold +# ): +# preds, target, _ = _input_format_classification( +# preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) + +# if mdmc_reduce == "global": +# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) +# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) + +# return _sk_stat_scores(preds, target, reduce, None, False, ignore_index, top_k, threshold) +# if mdmc_reduce == "samplewise": +# scores = [] + +# for i in range(preds.shape[0]): +# pred_i = preds[i, ...].T +# target_i = target[i, ...].T +# scores_i = _sk_stat_scores(pred_i, target_i, reduce, None, False, ignore_index, top_k, threshold) + +# scores.append(np.expand_dims(scores_i, 0)) + +# return np.concatenate(scores) + + +# @pytest.mark.parametrize( +# "reduce, mdmc_reduce, num_classes, inputs, ignore_index", +# [ +# ["unknown", None, None, _input_binary, None], +# ["micro", "unknown", None, _input_binary, None], +# ["macro", None, None, _input_binary, None], +# ["micro", None, None, _input_mdmc_prob, None], +# ["micro", None, None, _input_binary_prob, 0], +# ["micro", None, None, _input_mcls_prob, NUM_CLASSES], +# ["micro", None, NUM_CLASSES, _input_mcls_prob, NUM_CLASSES], +# ], +# ) +# def test_wrong_params(reduce, mdmc_reduce, num_classes, inputs, ignore_index): +# """Test a combination of parameters that are invalid and should raise an error. + +# This includes invalid ``reduce`` and ``mdmc_reduce`` parameter values, not setting ``num_classes`` when +# ``reduce='macro'`, not setting ``mdmc_reduce`` when inputs are multi-dim multi-class``, setting ``ignore_index`` +# when inputs are binary, as well as setting ``ignore_index`` to a value higher than the number of classes. 
+# """ +# with pytest.raises(ValueError): +# stat_scores( +# inputs.preds[0], inputs.target[0], reduce, mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index +# ) + +# with pytest.raises(ValueError): +# sts = StatScores(reduce=reduce, mdmc_reduce=mdmc_reduce, num_classes=num_classes, ignore_index=ignore_index) +# sts(inputs.preds[0], inputs.target[0]) + + +# @pytest.mark.parametrize("ignore_index", [None, 0]) +# @pytest.mark.parametrize("reduce", ["micro", "macro", "samples"]) +# @pytest.mark.parametrize( +# "preds, target, sk_fn, mdmc_reduce, num_classes, multiclass, top_k, threshold", +# [ +# (_input_binary_logits.preds, _input_binary_logits.target, _sk_stat_scores, None, 1, None, None, 0.0), +# (_input_binary_prob.preds, _input_binary_prob.target, _sk_stat_scores, None, 1, None, None, 0.5), +# (_input_binary.preds, _input_binary.target, _sk_stat_scores, None, 1, False, None, 0.5), +# (_input_mlb_logits.preds, _input_mlb_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), +# (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.5), +# (_input_mcls.preds, _input_mcls.target, _sk_stat_scores, None, NUM_CLASSES, False, None, 0.5), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.5), +# (_input_mcls_logits.preds, _input_mcls_logits.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_stat_scores, None, NUM_CLASSES, None, 2, 0.0), +# (_input_multiclass.preds, _input_multiclass.target, _sk_stat_scores, None, NUM_CLASSES, None, None, 0.0), +# ( +# _input_mdmc.preds, +# _input_mdmc.target, +# _sk_stat_scores_mdim_mcls, +# "samplewise", +# NUM_CLASSES, +# None, +# None, +# 0.0 +# ), +# ( +# _input_mdmc_prob.preds, +# _input_mdmc_prob.target, +# _sk_stat_scores_mdim_mcls, +# "samplewise", +# NUM_CLASSES, +# None, +# None, +# 0.0, +# ), +# (_input_mdmc.preds, _input_mdmc.target, _sk_stat_scores_mdim_mcls, "global", NUM_CLASSES, None, None, 0.0), +# ( +# _input_mdmc_prob.preds, +# _input_mdmc_prob.target, +# _sk_stat_scores_mdim_mcls, +# "global", +# NUM_CLASSES, +# None, +# None, +# 0.0, +# ), +# ], +# ) +# class TestStatScores(MetricTester): +# # DDP tests temporarily disabled due to hanging issues +# @pytest.mark.parametrize("ddp", [False]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# @pytest.mark.parametrize("dtype", [torch.float, torch.double]) +# def test_stat_scores_class( +# self, +# ddp: bool, +# dist_sync_on_step: bool, +# dtype: torch.dtype, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# if preds.is_floating_point(): +# preds = preds.to(dtype) +# if target.is_floating_point(): +# target = target.to(dtype) + +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=StatScores, +# sk_metric=partial( +# sk_fn, +# reduce=reduce, +# mdmc_reduce=mdmc_reduce, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# top_k=top_k, +# threshold=threshold, +# ), +# 
dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + +# def test_stat_scores_fn( +# self, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=stat_scores, +# sk_metric=partial( +# sk_fn, +# reduce=reduce, +# mdmc_reduce=mdmc_reduce, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# top_k=top_k, +# threshold=threshold, +# ), +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + +# def test_stat_scores_differentiability( +# self, +# sk_fn: Callable, +# preds: Tensor, +# target: Tensor, +# reduce: str, +# mdmc_reduce: Optional[str], +# num_classes: Optional[int], +# multiclass: Optional[bool], +# ignore_index: Optional[int], +# top_k: Optional[int], +# threshold: Optional[float], +# ): +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# self.run_differentiability_test( +# preds, +# target, +# metric_module=StatScores, +# metric_functional=stat_scores, +# metric_args={ +# "num_classes": num_classes, +# "reduce": reduce, +# "mdmc_reduce": mdmc_reduce, +# "threshold": threshold, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "top_k": top_k, +# }, +# ) + + +# _mc_k_target = tensor([0, 1, 2]) +# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +# @pytest.mark.parametrize( +# "k, preds, target, reduce, expected", +# [ +# (1, _mc_k_preds, _mc_k_target, "micro", tensor([2, 1, 5, 1, 3])), +# (2, _mc_k_preds, _mc_k_target, "micro", tensor([3, 3, 3, 0, 3])), +# (1, _ml_k_preds, _ml_k_target, "micro", tensor([0, 3, 3, 3, 3])), +# (2, _ml_k_preds, _ml_k_target, "micro", tensor([1, 5, 1, 2, 3])), +# (1, _mc_k_preds, _mc_k_target, "macro", tensor([[0, 1, 1], [0, 1, 0], [2, 1, 2], [1, 0, 0], [1, 1, 1]])), +# (2, _mc_k_preds, _mc_k_target, "macro", tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1], [0, 0, 0], [1, 1, 1]])), +# (1, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 0, 0], [1, 0, 2], [1, 1, 1], [1, 2, 0], [1, 2, 0]])), +# (2, _ml_k_preds, _ml_k_target, "macro", tensor([[0, 1, 0], [2, 0, 3], [0, 1, 0], [1, 1, 0], [1, 2, 0]])), +# ], +# ) +# def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Tensor): +# """A simple test to check that top_k works as expected.""" + +# class_metric = StatScores(top_k=k, reduce=reduce, num_classes=3) +# class_metric.update(preds, target) + +# assert torch.equal(class_metric.compute(), expected.T) +# assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) diff --git a/tests/unittests/helpers/testers.py b/tests/unittests/helpers/testers.py index 
92008ef5a4f..288a39e4d2a 100644 --- a/tests/unittests/helpers/testers.py +++ b/tests/unittests/helpers/testers.py @@ -14,6 +14,7 @@ import os import pickle import sys +from copy import deepcopy from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Union @@ -300,12 +301,13 @@ def _functional_test( _assert_allclose(tm_result, sk_result, atol=atol) -def _assert_half_support( +def _assert_dtype_support( metric_module: Optional[Metric], metric_functional: Optional[Callable], preds: Tensor, target: Tensor, device: str = "cpu", + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if an metric can be used with half precision tensors. @@ -319,10 +321,10 @@ def _assert_half_support( kwargs_update: Additional keyword arguments that will be passed with preds and target when running update on the metric. """ - y_hat = preds[0].half().to(device) if preds[0].is_floating_point() else preds[0].to(device) - y = target[0].half().to(device) if target[0].is_floating_point() else target[0].to(device) + y_hat = preds[0].to(dtype=dtype, device=device) if preds[0].is_floating_point() else preds[0].to(device) + y = target[0].to(dtype=dtype, device=device) if target[0].is_floating_point() else target[0].to(device) kwargs_update = { - k: (v[0].half() if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v + k: (v[0].to(dtype=dtype) if v.is_floating_point() else v[0]).to(device) if isinstance(v, Tensor) else v for k, v in kwargs_update.items() } if metric_module is not None: @@ -402,7 +404,7 @@ def run_class_metric_test( target: Union[Tensor, List[Dict]], metric_class: Metric, sk_metric: Callable, - dist_sync_on_step: bool, + dist_sync_on_step: bool = False, metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, @@ -482,6 +484,7 @@ def run_precision_test_cpu( metric_module: Optional[Metric] = None, metric_functional: Optional[Callable] = None, metric_args: Optional[dict] = None, + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if a metric can be used with half precision tensors on cpu @@ -495,12 +498,13 @@ def run_precision_test_cpu( target when running update on the metric. """ metric_args = metric_args or {} - _assert_half_support( + _assert_dtype_support( metric_module(**metric_args) if metric_module is not None else None, - metric_functional, + partial(metric_functional, **metric_args) if metric_functional is not None else None, preds, target, device="cpu", + dtype=dtype, **kwargs_update, ) @@ -511,6 +515,7 @@ def run_precision_test_gpu( metric_module: Optional[Metric] = None, metric_functional: Optional[Callable] = None, metric_args: Optional[dict] = None, + dtype: torch.dtype = torch.half, **kwargs_update, ): """Test if a metric can be used with half precision tensors on gpu @@ -524,12 +529,13 @@ def run_precision_test_gpu( target when running update on the metric. 
""" metric_args = metric_args or {} - _assert_half_support( + _assert_dtype_support( metric_module(**metric_args) if metric_module is not None else None, - metric_functional, + partial(metric_functional, **metric_args) if metric_functional is not None else None, preds, target, device="cuda", + dtype=dtype, **kwargs_update, ) @@ -619,3 +625,10 @@ def compute(self): class DummyMetricMultiOutput(DummyMetricSum): def compute(self): return [self.x, self.x] + + +def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: + idx = torch.randperm(x.numel()) + x = deepcopy(x) + x.view(-1)[idx[::5]] = ignore_index + return x diff --git a/tests/unittests/retrieval/helpers.py b/tests/unittests/retrieval/helpers.py index 2561d0fb5c7..4c41479df80 100644 --- a/tests/unittests/retrieval/helpers.py +++ b/tests/unittests/retrieval/helpers.py @@ -485,7 +485,7 @@ def run_precision_test_cpu( metric_module: Metric, metric_functional: Callable, ): - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_cpu( @@ -508,7 +508,7 @@ def run_precision_test_gpu( if not torch.cuda.is_available(): pytest.skip("Test requires GPU") - def metric_functional_ignore_indexes(preds, target, indexes): + def metric_functional_ignore_indexes(preds, target, indexes, empty_target_action): return metric_functional(preds, target) super().run_precision_test_gpu( diff --git a/tests/unittests/test_utilities.py b/tests/unittests/test_utilities.py index 9f5a5ccc222..d88ac3d9eb3 100644 --- a/tests/unittests/test_utilities.py +++ b/tests/unittests/test_utilities.py @@ -20,6 +20,7 @@ from torchmetrics.utilities.checks import _allclose_recursive from torchmetrics.utilities.data import _bincount, _flatten, _flatten_dict, to_categorical, to_onehot from torchmetrics.utilities.distributed import class_reduce, reduce +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_7 def test_prints(): @@ -113,7 +114,7 @@ def test_bincount(): """test that bincount works in deterministic setting on GPU.""" torch.use_deterministic_algorithms(True) - x = torch.randint(100, size=(100,)) + x = torch.randint(10, size=(100,)) # uses custom implementation res1 = _bincount(x, minlength=10) @@ -157,3 +158,14 @@ def test_check_full_state_update_fn(capsys, metric_class, expected): def test_recursive_allclose(input, expected): res = _allclose_recursive(*input) assert res == expected + + +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="test requires access to `torch.movedim`") +@pytest.mark.parametrize("dim1, dim2", [(1, 3), (1, -1)]) +def test_movedim(dim1, dim2): + x = torch.randn(5, 4, 3, 2, 1) + res1 = torch.movedim(x, dim1, dim2) + if dim2 >= 0: + dim2 += 1 + res2 = x.unsqueeze(dim2).transpose(dim2, dim1).squeeze(dim1) + assert torch.allclose(res1, res2)