diff --git a/docs/source/classification/hamming_distance.rst b/docs/source/classification/hamming_distance.rst index 4af52d20492..b86c41b2278 100644 --- a/docs/source/classification/hamming_distance.rst +++ b/docs/source/classification/hamming_distance.rst @@ -10,11 +10,53 @@ Hamming Distance Module Interface ________________ +HammingDistance +^^^^^^^^^^^^^^^ + .. autoclass:: torchmetrics.HammingDistance :noindex: +BinaryHammingDistance +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinaryHammingDistance + :noindex: + +MulticlassHammingDistance +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassHammingDistance + :noindex: + +MultilabelHammingDistance +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelHammingDistance + :noindex: + Functional Interface ____________________ +hamming_distance +^^^^^^^^^^^^^^^^ + .. autofunction:: torchmetrics.functional.hamming_distance :noindex: + +binary_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_hamming_distance + :noindex: + +multiclass_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_hamming_distance + :noindex: + +multilabel_hamming_distance +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multilabel_hamming_distance + :noindex: diff --git a/docs/source/classification/specificity.rst b/docs/source/classification/specificity.rst index 4d0aef5eda4..00e0bbb8932 100644 --- a/docs/source/classification/specificity.rst +++ b/docs/source/classification/specificity.rst @@ -13,8 +13,45 @@ ________________ .. autoclass:: torchmetrics.Specificity :noindex: +BinarySpecificity +^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.BinarySpecificity + :noindex: + +MulticlassSpecificity +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MulticlassSpecificity + :noindex: + +MultilabelSpecificity +^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.MultilabelSpecificity + :noindex: + + Functional Interface ____________________ .. autofunction:: torchmetrics.functional.specificity :noindex: + +binary_specificity +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.binary_specificity + :noindex: + +multiclass_specificity +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.multiclass_specificity + :noindex: + +multilabel_specificity +^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autofunction:: torchmetrics.functional.multilabel_specificity + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index aa5c100d890..5230a56d1bd 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -30,10 +30,12 @@ BinaryConfusionMatrix, BinaryF1Score, BinaryFBetaScore, + BinaryHammingDistance, BinaryJaccardIndex, BinaryMatthewsCorrCoef, BinaryPrecision, BinaryRecall, + BinarySpecificity, BinaryStatScores, BinnedAveragePrecision, BinnedPrecisionRecallCurve, @@ -56,18 +58,22 @@ MulticlassConfusionMatrix, MulticlassF1Score, MulticlassFBetaScore, + MulticlassHammingDistance, MulticlassJaccardIndex, MulticlassMatthewsCorrCoef, MulticlassPrecision, MulticlassRecall, + MulticlassSpecificity, MulticlassStatScores, MultilabelConfusionMatrix, MultilabelF1Score, MultilabelFBetaScore, + MultilabelHammingDistance, MultilabelJaccardIndex, MultilabelMatthewsCorrCoef, MultilabelPrecision, MultilabelRecall, + MultilabelSpecificity, MultilabelStatScores, Precision, PrecisionRecallCurve, @@ -173,6 +179,9 @@ "MulticlassFBetaScore", "MultilabelFBetaScore", "HammingDistance", + "BinaryHammingDistance", + "MultilabelHammingDistance", + "MulticlassHammingDistance", "HingeLoss", "JaccardIndex", "BinaryJaccardIndex", @@ -231,6 +240,9 @@ "SignalNoiseRatio", "SpearmanCorrCoef", "Specificity", + "BinarySpecificity", + "MulticlassSpecificity", + "MultilabelSpecificity", "SpectralAngleMapper", "SpectralDistortionIndex", "SQuAD", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 8466be98e55..2a08575de92 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -44,7 +44,12 @@ MultilabelF1Score, MultilabelFBetaScore, ) -from torchmetrics.classification.hamming import HammingDistance # noqa: F401 +from torchmetrics.classification.hamming import ( # noqa: F401 + BinaryHammingDistance, + HammingDistance, + MulticlassHammingDistance, + MultilabelHammingDistance, +) from torchmetrics.classification.hinge import HingeLoss # noqa: F401 from torchmetrics.classification.jaccard import ( # noqa: F401 BinaryJaccardIndex, @@ -76,4 +81,9 @@ LabelRankingLoss, ) from torchmetrics.classification.roc import ROC # noqa: F401 -from torchmetrics.classification.specificity import Specificity # noqa: F401 +from torchmetrics.classification.specificity import ( # noqa: F401 + BinarySpecificity, + MulticlassSpecificity, + MultilabelSpecificity, + Specificity, +) diff --git a/src/torchmetrics/classification/hamming.py b/src/torchmetrics/classification/hamming.py index be9a0bf430e..0a9a69712b2 100644 --- a/src/torchmetrics/classification/hamming.py +++ b/src/torchmetrics/classification/hamming.py @@ -16,10 +16,309 @@ import torch from torch import Tensor, tensor -from torchmetrics.functional.classification.hamming import _hamming_distance_compute, _hamming_distance_update +from torchmetrics.classification.stat_scores import BinaryStatScores, MulticlassStatScores, MultilabelStatScores +from torchmetrics.functional.classification.hamming import ( + _hamming_distance_compute, + _hamming_distance_reduce, + _hamming_distance_update, +) from torchmetrics.metric import Metric +class BinaryHammingDistance(BinaryStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for binary tasks: + + .. 
math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics import BinaryHammingDistance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinaryHammingDistance() + >>> metric(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics import BinaryHammingDistance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinaryHammingDistance() + >>> metric(preds, target) + tensor(0.3333) + + Example (multidim tensors): + >>> from torchmetrics import BinaryHammingDistance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinaryHammingDistance(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.8333]) + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + + +class MulticlassHammingDistance(MulticlassStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. 
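(Editor's illustration, not part of the patch.) The formula above is shared by all task variants; for binary inputs it reduces to the element-wise mismatch rate between ``preds`` and ``target``. Reusing the tensors from the ``BinaryHammingDistance`` example above, the documented value can be reproduced with plain PyTorch:

    >>> import torch
    >>> target = torch.tensor([0, 1, 0, 1, 0, 1])
    >>> preds = torch.tensor([0, 0, 1, 1, 0, 1])
    >>> (preds != target).float().mean()  # 2 mismatches out of 6 elements
    tensor(0.3333)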
+ + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics import MulticlassHammingDistance + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassHammingDistance(num_classes=3) + >>> metric(preds, target) + tensor(0.2500) + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 0.0000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics import MulticlassHammingDistance + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... 
]) + >>> metric = MulticlassHammingDistance(num_classes=3) + >>> metric(preds, target) + tensor(0.2500) + >>> metric = MulticlassHammingDistance(num_classes=3, average=None) + >>> metric(preds, target) + tensor([0.5000, 0.0000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics import MulticlassHammingDistance + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.5000, 0.6667]) + >>> metric = MulticlassHammingDistance(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.0000, 1.0000, 0.5000], + [1.0000, 0.6667, 0.5000]]) + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelHammingDistance(MultilabelStatScores): + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. 
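(Editor's illustration, not part of the patch.) For multilabel inputs, ``average=None`` yields the per-label mismatch rate and ``average='micro'`` the mismatch rate over all elements. A quick check with the tensors used in the examples below, assuming plain PyTorch:

    >>> import torch
    >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]])
    >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]])
    >>> (preds != target).float().mean(dim=0)  # per-label mismatch rate
    tensor([0.0000, 0.5000, 0.5000])
    >>> (preds != target).float().mean()  # micro average over all entries
    tensor(0.3333)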
+ + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics import MultilabelHammingDistance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + >>> metric = MultilabelHammingDistance(num_labels=3, average=None) + >>> metric(preds, target) + tensor([0.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics import MultilabelHammingDistance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelHammingDistance(num_labels=3) + >>> metric(preds, target) + tensor(0.3333) + >>> metric = MultilabelHammingDistance(num_labels=3, average=None) + >>> metric(preds, target) + tensor([0.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics import MultilabelHammingDistance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.6667, 0.8333]) + >>> metric = MultilabelHammingDistance(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.5000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000]]) + + """ + + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = False + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _hamming_distance_reduce( + tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, multilabel=True + ) + + +# -------------------------- Old stuff -------------------------- + + class HammingDistance(Metric): r"""Computes the average `Hamming distance`_ (also known as Hamming loss) between targets and predictions: diff --git a/src/torchmetrics/classification/specificity.py b/src/torchmetrics/classification/specificity.py index 56057a9cfbd..2330be52263 100644 --- a/src/torchmetrics/classification/specificity.py +++ b/src/torchmetrics/classification/specificity.py @@ -16,11 +16,284 @@ import torch from torch import Tensor -from torchmetrics.classification.stat_scores import StatScores -from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.classification.stat_scores import ( + BinaryStatScores, + MulticlassStatScores, + MultilabelStatScores, + StatScores, +) +from torchmetrics.functional.classification.specificity import _specificity_compute, _specificity_reduce from torchmetrics.utilities.enums import AverageMethod +class BinarySpecificity(BinaryStatScores): + r"""Computes `Specificity`_ for binary tasks: + + .. 
math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics import BinarySpecificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> metric = BinarySpecificity() + >>> metric(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics import BinarySpecificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> metric = BinarySpecificity() + >>> metric(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics import BinarySpecificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> metric = BinarySpecificity(multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.0000, 0.3333]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=self.multidim_average) + + +class MulticlassSpecificity(MulticlassStatScores): + r"""Computes `Specificity`_ for multiclass tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. 
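(Editor's illustration, not part of the patch.) Multiclass specificity is computed one-vs-rest from the true-negative and false-positive counts of each class. For the first example below (``target = [2, 1, 0, 0]``, ``preds = [2, 1, 0, 1]``), a hand computation in plain PyTorch reproduces the documented per-class values:

    >>> import torch
    >>> target = torch.tensor([2, 1, 0, 0])
    >>> preds = torch.tensor([2, 1, 0, 1])
    >>> spec = []
    >>> for c in range(3):
    ...     tn = ((target != c) & (preds != c)).sum()
    ...     fp = ((target != c) & (preds == c)).sum()
    ...     spec.append(tn / (tn + fp))
    >>> torch.stack(spec)
    tensor([1.0000, 0.6667, 1.0000])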
+ + Args: + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics import MulticlassSpecificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> metric = MulticlassSpecificity(num_classes=3) + >>> metric(preds, target) + tensor(0.8750) + >>> metric = MulticlassSpecificity(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics import MulticlassSpecificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... ]) + >>> metric = MulticlassSpecificity(num_classes=3) + >>> metric(preds, target) + tensor(0.8750) + >>> metric = MulticlassSpecificity(num_classes=3, average=None) + >>> metric(preds, target) + tensor([1.0000, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics import MulticlassSpecificity + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.7500, 0.6667]) + >>> metric = MulticlassSpecificity(num_classes=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0.7500, 0.7500, 0.7500], + [0.8000, 0.6667, 0.5000]]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +class MultilabelSpecificity(MultilabelStatScores): + r"""Computes `Specificity`_ for multilabel tasks + + .. 
math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + - If ``multidim_average`` is set to ``global`` + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + - If ``multidim_average`` is set to ``samplewise`` + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics import MultilabelSpecificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> metric = MultilabelSpecificity(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelSpecificity(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 1., 0.]) + + Example (preds is float tensor): + >>> from torchmetrics import MultilabelSpecificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> metric = MultilabelSpecificity(num_labels=3) + >>> metric(preds, target) + tensor(0.6667) + >>> metric = MultilabelSpecificity(num_labels=3, average=None) + >>> metric(preds, target) + tensor([1., 1., 0.]) + + Example (multidim tensors): + >>> from torchmetrics import MultilabelSpecificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... 
) + >>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise') + >>> metric(preds, target) + tensor([0.0000, 0.3333]) + >>> metric = MultilabelSpecificity(num_labels=3, multidim_average='samplewise', average=None) + >>> metric(preds, target) + tensor([[0., 0., 0.], + [0., 0., 1.]]) + """ + + def compute(self) -> Tensor: + tp, fp, tn, fn = self._final_state() + return _specificity_reduce(tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average) + + +# -------------------------- Old stuff -------------------------- + + class Specificity(StatScores): r"""Computes `Specificity`_: diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 652bb1bb612..321311cb60a 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -37,7 +37,12 @@ multilabel_f1_score, multilabel_fbeta_score, ) -from torchmetrics.functional.classification.hamming import hamming_distance +from torchmetrics.functional.classification.hamming import ( + binary_hamming_distance, + hamming_distance, + multiclass_hamming_distance, + multilabel_hamming_distance, +) from torchmetrics.functional.classification.hinge import hinge_loss from torchmetrics.functional.classification.jaccard import ( binary_jaccard_index, @@ -70,7 +75,12 @@ label_ranking_loss, ) from torchmetrics.functional.classification.roc import roc -from torchmetrics.functional.classification.specificity import specificity +from torchmetrics.functional.classification.specificity import ( + binary_specificity, + multiclass_specificity, + multilabel_specificity, + specificity, +) from torchmetrics.functional.classification.stat_scores import ( binary_stat_scores, multiclass_stat_scores, @@ -231,6 +241,12 @@ "binary_matthews_corrcoef", "multiclass_matthews_corrcoef", "multilabel_matthews_corrcoef", + "binary_specificity", + "multiclass_specificity", + "multilabel_specificity", + "binary_hamming_distance", + "multiclass_hamming_distance", + "multilabel_hamming_distance", "binary_precision", "multiclass_precision", "multilabel_precision", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 4026a15df92..05ab3b92ca9 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -34,7 +34,12 @@ multilabel_f1_score, multilabel_fbeta_score, ) -from torchmetrics.functional.classification.hamming import hamming_distance # noqa: F401 +from torchmetrics.functional.classification.hamming import ( # noqa: F401 + binary_hamming_distance, + hamming_distance, + multiclass_hamming_distance, + multilabel_hamming_distance, +) from torchmetrics.functional.classification.hinge import hinge_loss # noqa: F401 from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 from torchmetrics.functional.classification.kl_divergence import kl_divergence # noqa: F401 @@ -57,7 +62,12 @@ label_ranking_loss, ) from torchmetrics.functional.classification.roc import roc # noqa: F401 -from torchmetrics.functional.classification.specificity import specificity # noqa: F401 +from torchmetrics.functional.classification.specificity import ( # noqa: F401 + binary_specificity, + multiclass_specificity, + multilabel_specificity, + specificity, +) from torchmetrics.functional.classification.stat_scores import ( # noqa: F401 binary_stat_scores, multiclass_stat_scores, diff --git 
a/src/torchmetrics/functional/classification/hamming.py b/src/torchmetrics/functional/classification/hamming.py
index 44fedd577ae..7e49819cc0b 100644
--- a/src/torchmetrics/functional/classification/hamming.py
+++ b/src/torchmetrics/functional/classification/hamming.py
@@ -11,12 +11,382 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import Tensor
+from typing_extensions import Literal
 
+from torchmetrics.functional.classification.stat_scores import (
+    _binary_stat_scores_arg_validation,
+    _binary_stat_scores_format,
+    _binary_stat_scores_tensor_validation,
+    _binary_stat_scores_update,
+    _multiclass_stat_scores_arg_validation,
+    _multiclass_stat_scores_format,
+    _multiclass_stat_scores_tensor_validation,
+    _multiclass_stat_scores_update,
+    _multilabel_stat_scores_arg_validation,
+    _multilabel_stat_scores_format,
+    _multilabel_stat_scores_tensor_validation,
+    _multilabel_stat_scores_update,
+)
 from torchmetrics.utilities.checks import _input_format_classification
+from torchmetrics.utilities.compute import _safe_divide
+
+
+def _hamming_distance_reduce(
+    tp: Tensor,
+    fp: Tensor,
+    tn: Tensor,
+    fn: Tensor,
+    average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
+    multidim_average: Literal["global", "samplewise"] = "global",
+    multilabel: bool = False,
+) -> Tensor:
+    """Reduce classification statistics into hamming distance.
+
+    Args:
+        tp: number of true positives
+        fp: number of false positives
+        tn: number of true negatives
+        fn: number of false negatives
+        average: defines the reduction that is applied over labels. Should be one of
+            ``"binary"``, ``"micro"``, ``"macro"``, ``"weighted"`` or ``"none"``/``None``.
+        multidim_average: defines how additional dimensions should be handled,
+            either ``"global"`` or ``"samplewise"``
+        multilabel: bool indicating if reduction is for multilabel tasks
+
+    Returns:
+        Hamming distance score
+    """
+    if average == "binary":
+        return 1 - _safe_divide(tp + tn, tp + fp + tn + fn)
+    elif average == "micro":
+        tp = tp.sum(dim=0 if multidim_average == "global" else 1)
+        fn = fn.sum(dim=0 if multidim_average == "global" else 1)
+        if multilabel:
+            fp = fp.sum(dim=0 if multidim_average == "global" else 1)
+            tn = tn.sum(dim=0 if multidim_average == "global" else 1)
+            return 1 - _safe_divide(tp + tn, tp + tn + fp + fn)
+        return 1 - _safe_divide(tp, tp + fn)
+    else:
+        if multilabel:
+            score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn)
+        else:
+            score = 1 - _safe_divide(tp, tp + fn)
+        if average is None or average == "none":
+            return score
+        if average == "weighted":
+            weights = tp + fn
+        else:
+            weights = torch.ones_like(score)
+        return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
+
+
+def binary_hamming_distance(
+    preds: Tensor,
+    target: Tensor,
+    threshold: float = 0.5,
+    multidim_average: Literal["global", "samplewise"] = "global",
+    ignore_index: Optional[int] = None,
+    validate_args: bool = True,
+) -> Tensor:
+    r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for binary tasks:
+
+    .. math::
+        \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il})
+
+    Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions,
+    and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that
+    tensor.
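(Editor's sketch, not part of the patch.) The ``_hamming_distance_reduce`` helper above turns the stat-score counts into the final score; for multiclass inputs the non-micro branches amount to a per-class ``1 - recall`` followed by macro or support-weighted averaging. A simplified standalone version, assuming plain division in place of ``_safe_divide`` and ignoring the ``samplewise`` and multilabel branches:

    import torch
    from torch import Tensor

    def hamming_reduce_sketch(tp: Tensor, fn: Tensor, average: str = "macro") -> Tensor:
        # per-class hamming score for multiclass inputs: 1 - recall of each class
        score = 1.0 - tp / (tp + fn)
        if average in (None, "none"):
            return score
        # "weighted" weighs each class by its support (tp + fn); "macro" weighs all classes equally
        weights = (tp + fn).float() if average == "weighted" else torch.ones_like(score)
        return (weights * score).sum(-1) / weights.sum(-1)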
+ + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_hamming_distance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_hamming_distance(preds, target) + tensor(0.3333) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_hamming_distance + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_hamming_distance(preds, target) + tensor(0.3333) + + Example (multidim tensors): + >>> from torchmetrics.functional import binary_hamming_distance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_hamming_distance(preds, target, multidim_average='samplewise') + tensor([0.6667, 0.8333]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _hamming_distance_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_hamming_distance( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multiclass tasks: + + .. 
math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_hamming_distance + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_hamming_distance(preds, target, num_classes=3) + tensor(0.2500) + >>> multiclass_hamming_distance(preds, target, num_classes=3, average=None) + tensor([0.5000, 0.0000, 0.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_hamming_distance + >>> target = target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... 
]) + >>> multiclass_hamming_distance(preds, target, num_classes=3) + tensor(0.2500) + >>> multiclass_hamming_distance(preds, target, num_classes=3, average=None) + tensor([0.5000, 0.0000, 0.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multiclass_hamming_distance + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.5000, 0.6667]) + >>> multiclass_hamming_distance(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.0000, 1.0000, 0.5000], + [1.0000, 0.6667, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_hamming_distance( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes the average `Hamming distance`_ (also known as Hamming loss) for multilabel tasks: + + .. math:: + \text{Hamming distance} = \frac{1}{N \cdot L} \sum_i^N \sum_l^L 1(y_{il} \neq \hat{y}_{il}) + + Where :math:`y` is a tensor of target values, :math:`\hat{y}` is a tensor of predictions, + and :math:`\bullet_{il}` refers to the :math:`l`-th label of the :math:`i`-th sample of that + tensor. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifing the number of labels + threshold: Threshold for transforming probability to binary (0,1) predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. 
+ + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_hamming_distance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_hamming_distance(preds, target, num_labels=3) + tensor(0.3333) + >>> multilabel_hamming_distance(preds, target, num_labels=3, average=None) + tensor([0.0000, 0.5000, 0.5000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_hamming_distance + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_hamming_distance(preds, target, num_labels=3) + tensor(0.3333) + >>> multilabel_hamming_distance(preds, target, num_labels=3, average=None) + tensor([0.0000, 0.5000, 0.5000]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_hamming_distance + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... 
) + >>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.6667, 0.8333]) + >>> multilabel_hamming_distance(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0.5000, 0.5000, 1.0000], + [1.0000, 1.0000, 0.5000]]) + + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _hamming_distance_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average, multilabel=True) + + +# -------------------------- Old stuff -------------------------- def _hamming_distance_update( diff --git a/src/torchmetrics/functional/classification/specificity.py b/src/torchmetrics/functional/classification/specificity.py index 193f01fe395..66527083755 100644 --- a/src/torchmetrics/functional/classification/specificity.py +++ b/src/torchmetrics/functional/classification/specificity.py @@ -15,11 +15,349 @@ import torch from torch import Tensor - -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores, _stat_scores_update +from typing_extensions import Literal + +from torchmetrics.functional.classification.stat_scores import ( + _binary_stat_scores_arg_validation, + _binary_stat_scores_format, + _binary_stat_scores_tensor_validation, + _binary_stat_scores_update, + _multiclass_stat_scores_arg_validation, + _multiclass_stat_scores_format, + _multiclass_stat_scores_tensor_validation, + _multiclass_stat_scores_update, + _multilabel_stat_scores_arg_validation, + _multilabel_stat_scores_format, + _multilabel_stat_scores_tensor_validation, + _multilabel_stat_scores_update, + _reduce_stat_scores, + _stat_scores_update, +) +from torchmetrics.utilities.compute import _safe_divide from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod +def _specificity_reduce( + tp: Tensor, + fp: Tensor, + tn: Tensor, + fn: Tensor, + average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]], + multidim_average: Literal["global", "samplewise"] = "global", +) -> Tensor: + if average == "binary": + return _safe_divide(tn, tn + fp) + elif average == "micro": + tn = tn.sum(dim=0 if multidim_average == "global" else 1) + fp = fp.sum(dim=0 if multidim_average == "global" else 1) + return _safe_divide(tn, tn + fp) + else: + specificity_score = _safe_divide(tn, tn + fp) + if average is None or average == "none": + return specificity_score + if average == "weighted": + weights = tp + fn + else: + weights = torch.ones_like(specificity_score) + return _safe_divide(weights * specificity_score, weights.sum(-1, keepdim=True)).sum(-1) + + +def binary_specificity( + preds: Tensor, + target: Tensor, + threshold: float = 0.5, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for binary tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, ...)``. 
If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Addtionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + threshold: Threshold for transforming probability to binary {0,1} predictions + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + If ``multidim_average`` is set to ``global``, the metric returns a scalar value. If ``multidim_average`` + is set to ``samplewise``, the metric returns ``(N,)`` vector consisting of a scalar value per sample. + + Example (preds is int tensor): + >>> from torchmetrics.functional import binary_specificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0, 0, 1, 1, 0, 1]) + >>> binary_specificity(preds, target) + tensor(0.6667) + + Example (preds is float tensor): + >>> from torchmetrics.functional import binary_specificity + >>> target = torch.tensor([0, 1, 0, 1, 0, 1]) + >>> preds = torch.tensor([0.11, 0.22, 0.84, 0.73, 0.33, 0.92]) + >>> binary_specificity(preds, target) + tensor(0.6667) + + Example (multidim tensors): + >>> from torchmetrics.functional import binary_specificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> binary_specificity(preds, target, multidim_average='samplewise') + tensor([0.0000, 0.3333]) + """ + if validate_args: + _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index) + _binary_stat_scores_tensor_validation(preds, target, multidim_average, ignore_index) + preds, target = _binary_stat_scores_format(preds, target, threshold, ignore_index) + tp, fp, tn, fn = _binary_stat_scores_update(preds, target, multidim_average) + return _specificity_reduce(tp, fp, tn, fn, average="binary", multidim_average=multidim_average) + + +def multiclass_specificity( + preds: Tensor, + target: Tensor, + num_classes: int, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + top_k: int = 1, + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for multiclass tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respecitively. + + Accepts the following input tensors: + + - ``preds``: ``(N, ...)`` (int tensor) or ``(N, C, ..)`` (float tensor). 
If preds is a floating point + we apply ``torch.argmax`` along the ``C`` dimension to automatically convert probabilities/logits into + an int tensor. + - ``target`` (int tensor): ``(N, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifing the number of classes + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + top_k: + Number of highest probability or logit score predictions considered to find the correct label. + Only works when ``preds`` contain probabilities/logits. + multidim_average: + Defines how additionally dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flatted along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multiclass_specificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([2, 1, 0, 1]) + >>> multiclass_specificity(preds, target, num_classes=3) + tensor(0.8750) + >>> multiclass_specificity(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.6667, 1.0000]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multiclass_specificity + >>> target = torch.tensor([2, 1, 0, 0]) + >>> preds = torch.tensor([ + ... [0.16, 0.26, 0.58], + ... [0.22, 0.61, 0.17], + ... [0.71, 0.09, 0.20], + ... [0.05, 0.82, 0.13], + ... 
]) + >>> multiclass_specificity(preds, target, num_classes=3) + tensor(0.8750) + >>> multiclass_specificity(preds, target, num_classes=3, average=None) + tensor([1.0000, 0.6667, 1.0000]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multiclass_specificity + >>> target = torch.tensor([[[0, 1], [2, 1], [0, 2]], [[1, 1], [2, 0], [1, 2]]]) + >>> preds = torch.tensor([[[0, 2], [2, 0], [0, 1]], [[2, 2], [2, 1], [1, 0]]]) + >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise') + tensor([0.7500, 0.6667]) + >>> multiclass_specificity(preds, target, num_classes=3, multidim_average='samplewise', average=None) + tensor([[0.7500, 0.7500, 0.7500], + [0.8000, 0.6667, 0.5000]]) + """ + if validate_args: + _multiclass_stat_scores_arg_validation(num_classes, top_k, average, multidim_average, ignore_index) + _multiclass_stat_scores_tensor_validation(preds, target, num_classes, multidim_average, ignore_index) + preds, target = _multiclass_stat_scores_format(preds, target, top_k) + tp, fp, tn, fn = _multiclass_stat_scores_update(preds, target, num_classes, top_k, multidim_average, ignore_index) + return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +def multilabel_specificity( + preds: Tensor, + target: Tensor, + num_labels: int, + threshold: float = 0.5, + average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", + multidim_average: Literal["global", "samplewise"] = "global", + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Computes `Specificity`_ for multilabel tasks: + + .. math:: \text{Specificity} = \frac{\text{TN}}{\text{TN} + \text{FP}} + + Where :math:`\text{TN}` and :math:`\text{FP}` represent the number of true negatives and + false positives respectively. + + Accepts the following input tensors: + + - ``preds`` (int or float tensor): ``(N, C, ...)``. If preds is a floating point tensor with values outside + [0,1] range we consider the input to be logits and will auto apply sigmoid per element. Additionally, + we convert to int tensor with thresholding using the value in ``threshold``. + - ``target`` (int tensor): ``(N, C, ...)`` + + The influence of the additional dimension ``...`` (if present) will be determined by the `multidim_average` + argument. + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifying the number of labels + threshold: Threshold for transforming probability to binary {0,1} predictions + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``micro``: Sum statistics over all labels + - ``macro``: Calculate statistics for each label and average them + - ``weighted``: Calculates statistics for each label and computes weighted average using their support + - ``"none"`` or ``None``: Calculates statistic for each label and applies no reduction + + multidim_average: + Defines how additional dimensions ``...`` should be handled. Should be one of the following: + + - ``global``: Additional dimensions are flattened along the batch dimension + - ``samplewise``: Statistic will be calculated independently for each sample on the ``N`` axis. + The statistics in this case are calculated over the additional dimensions. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness.
+ Set to ``False`` for faster computations. + + Returns: + The returned shape depends on the ``average`` and ``multidim_average`` arguments: + + - If ``multidim_average`` is set to ``global``: + + - If ``average='micro'/'macro'/'weighted'``, the output will be a scalar tensor + - If ``average=None/'none'``, the shape will be ``(C,)`` + + - If ``multidim_average`` is set to ``samplewise``: + + - If ``average='micro'/'macro'/'weighted'``, the shape will be ``(N,)`` + - If ``average=None/'none'``, the shape will be ``(N, C)`` + + Example (preds is int tensor): + >>> from torchmetrics.functional import multilabel_specificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0, 0, 1], [1, 0, 1]]) + >>> multilabel_specificity(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_specificity(preds, target, num_labels=3, average=None) + tensor([1., 1., 0.]) + + Example (preds is float tensor): + >>> from torchmetrics.functional import multilabel_specificity + >>> target = torch.tensor([[0, 1, 0], [1, 0, 1]]) + >>> preds = torch.tensor([[0.11, 0.22, 0.84], [0.73, 0.33, 0.92]]) + >>> multilabel_specificity(preds, target, num_labels=3) + tensor(0.6667) + >>> multilabel_specificity(preds, target, num_labels=3, average=None) + tensor([1., 1., 0.]) + + Example (multidim tensors): + >>> from torchmetrics.functional import multilabel_specificity + >>> target = torch.tensor([[[0, 1], [1, 0], [0, 1]], [[1, 1], [0, 0], [1, 0]]]) + >>> preds = torch.tensor( + ... [ + ... [[0.59, 0.91], [0.91, 0.99], [0.63, 0.04]], + ... [[0.38, 0.04], [0.86, 0.780], [0.45, 0.37]], + ... ] + ... ) + >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise') + tensor([0.0000, 0.3333]) + >>> multilabel_specificity(preds, target, num_labels=3, multidim_average='samplewise', average=None) + tensor([[0., 0., 0.], + [0., 0., 1.]]) + """ + if validate_args: + _multilabel_stat_scores_arg_validation(num_labels, threshold, average, multidim_average, ignore_index) + _multilabel_stat_scores_tensor_validation(preds, target, num_labels, multidim_average, ignore_index) + preds, target = _multilabel_stat_scores_format(preds, target, num_labels, threshold, ignore_index) + tp, fp, tn, fn = _multilabel_stat_scores_update(preds, target, multidim_average) + return _specificity_reduce(tp, fp, tn, fn, average=average, multidim_average=multidim_average) + + +# -------------------------- Old stuff -------------------------- + + def _specificity_compute( tp: Tensor, fp: Tensor, diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 6fb4c46fd7a..72d8060757b 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -349,6 +349,66 @@ def test_top_k( assert torch.isclose(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) +def _sk_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + + fbeta_score, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + fbeta_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(fbeta_score, axis=0) + + if 
average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average): + fbeta_score, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + fbeta_score.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + fbeta_score.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(fbeta_score) + res = np.stack(fbeta_score, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + def _sk_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() @@ -365,61 +425,8 @@ def _sk_fbeta_score_multilabel(preds, target, sk_fn, ignore_index, multidim_aver average=average, ) elif multidim_average == "global": - if average == "micro": - preds = preds.flatten() - target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) - - fbeta_score, weights = [], [] - for i in range(preds.shape[1]): - pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - fbeta_score.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - res = np.stack(fbeta_score, axis=0) - - if average == "macro": - return res.mean(0) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res - else: - fbeta_score, weights = [], [] - for i in range(preds.shape[0]): - if average == "micro": - pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - fbeta_score.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - else: - scores, w = [], [] - for j in range(preds.shape[1]): - pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - w.append(confmat[1, 1] + confmat[1, 0]) - fbeta_score.append(np.stack(scores)) - weights.append(np.stack(w)) - if average == "micro": - return np.array(fbeta_score) - res = np.stack(fbeta_score, 0) - if average == "macro": 
- return res.mean(-1) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res + return _sk_fbeta_score_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _sk_fbeta_score_multilabel_local(preds, target, sk_fn, ignore_index, average) @pytest.mark.parametrize("input", _multilabel_cases) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index 11fab91d66b..6011ad5cbcc 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -11,96 +11,550 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + +import numpy as np import pytest +import torch +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix from sklearn.metrics import hamming_loss as sk_hamming_loss -from torchmetrics import HammingDistance -from torchmetrics.functional import hamming_distance -from torchmetrics.utilities.checks import _input_format_classification -from unittests.classification.inputs import _input_binary, _input_binary_logits, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_logits as _input_mcls_logits -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_logits as _input_mlb_logits -from unittests.classification.inputs import _input_multilabel_multidim as _input_mlmd -from unittests.classification.inputs import _input_multilabel_multidim_prob as _input_mlmd_prob -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.hamming import ( + BinaryHammingDistance, + MulticlassHammingDistance, + MultilabelHammingDistance, +) +from torchmetrics.functional.classification.hamming import ( + binary_hamming_distance, + multiclass_hamming_distance, + multilabel_hamming_distance, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index, remove_ignore_index seed_all(42) -def _sk_hamming_loss(preds, target): - sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) - sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() - sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), sk_target.reshape(sk_target.shape[0], -1) - - return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) - - -@pytest.mark.parametrize( - "preds, target", - [ - (_input_binary_logits.preds, 
_input_binary_logits.target), - (_input_binary_prob.preds, _input_binary_prob.target), - (_input_binary.preds, _input_binary.target), - (_input_mlb_logits.preds, _input_mlb_logits.target), - (_input_mlb_prob.preds, _input_mlb_prob.target), - (_input_mlb.preds, _input_mlb.target), - (_input_mcls_logits.preds, _input_mcls_logits.target), - (_input_mcls_prob.preds, _input_mcls_prob.target), - (_input_mcls.preds, _input_mcls.target), - (_input_mdmc_prob.preds, _input_mdmc_prob.target), - (_input_mdmc.preds, _input_mdmc.target), - (_input_mlmd_prob.preds, _input_mlmd_prob.target), - (_input_mlmd.preds, _input_mlmd.target), - ], -) -class TestHammingDistance(MetricTester): - @pytest.mark.parametrize("ddp", [True, False]) - @pytest.mark.parametrize("dist_sync_on_step", [False, True]) - def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): +def _sk_hamming_loss(target, preds): + score = sk_hamming_loss(target, preds) + return score if not np.isnan(score) else 1.0 + + +def _sk_hamming_distance_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() + else: + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_hamming_loss(target, preds) + else: + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + res.append(_sk_hamming_loss(true, pred)) + return np.stack(res) + + +@pytest.mark.parametrize("input", _binary_cases) +class TestBinaryHammingDistance(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_hamming_distance(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=HammingDistance, - sk_metric=_sk_hamming_loss, - dist_sync_on_step=dist_sync_on_step, - metric_args={"threshold": THRESHOLD}, + metric_class=BinaryHammingDistance, + sk_metric=partial( + _sk_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, ) - def test_hamming_distance_fn(self, preds, target): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_hamming_distance_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + self.run_functional_metric_test( preds=preds, target=target, - 
metric_functional=hamming_distance, - sk_metric=_sk_hamming_loss, - metric_args={"threshold": THRESHOLD}, + metric_functional=binary_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_binary, ignore_index=ignore_index, multidim_average=multidim_average + ), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, ) - def test_hamming_distance_differentiability(self, preds, target): + def test_binary_hamming_distance_differentiability(self, input): + preds, target = input self.run_differentiability_test( preds=preds, target=target, - metric_module=HammingDistance, - metric_functional=hamming_distance, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, + metric_args={"threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryHammingDistance, + metric_functional=binary_hamming_distance, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) + + +def _sk_hamming_distance_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + if average == "micro": + return _sk_hamming_loss(target, preds) + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) + hamming_per_class[np.isnan(hamming_per_class)] = 1.0 + if average == "macro": + return hamming_per_class.mean() + elif average == "weighted": + weights = confmat.sum(1) + return ((weights * hamming_per_class) / weights.sum()).sum() + return hamming_per_class + + +def _sk_hamming_distance_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + if average == "micro": + res.append(_sk_hamming_loss(true, pred)) + else: + confmat = sk_confusion_matrix(true, pred, labels=list(range(NUM_CLASSES))) + hamming_per_class = 1 - confmat.diagonal() / confmat.sum(axis=1) + hamming_per_class[np.isnan(hamming_per_class)] = 1.0 + if average == "macro": + res.append(hamming_per_class.mean()) + elif average == "weighted": + weights = confmat.sum(1) + score = ((weights * hamming_per_class) / weights.sum()).sum() + res.append(0.0 if np.isnan(score) else score) + else: + res.append(hamming_per_class) + return np.stack(res, 0) + + +def _sk_hamming_distance_multiclass(preds, target, ignore_index, 
multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _sk_hamming_distance_multiclass_global(preds, target, ignore_index, average) + return _sk_hamming_distance_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassHammingDistance(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_hamming_distance(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassHammingDistance, + sk_metric=partial( + _sk_hamming_distance_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) + def test_multiclass_hamming_distance_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) + + def test_multiclass_hamming_distance_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + 
@pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassHammingDistance, + metric_functional=multiclass_hamming_distance, + metric_args={"num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + +def _sk_hamming_distance_multilabel_global(preds, target, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return _sk_hamming_loss(target, preds) + + hamming, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + hamming.append(_sk_hamming_loss(true, pred)) + weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(hamming, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_hamming_distance_multilabel_local(preds, target, ignore_index, average): + hamming, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + hamming.append(_sk_hamming_loss(true, pred)) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(_sk_hamming_loss(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + hamming.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(hamming) + res = np.stack(hamming, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_hamming_distance_multilabel(preds, target, ignore_index, multidim_average, average): + preds = preds.numpy() + target = target.numpy() + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + preds = preds.reshape(*preds.shape[:2], -1) + target = target.reshape(*target.shape[:2], -1) + + if multidim_average == "global": + return _sk_hamming_distance_multilabel_global(preds, target, ignore_index, average) + return _sk_hamming_distance_multilabel_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multilabel_cases) +class TestMultilabelHammingDistance(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_hamming_distance(self, ddp, input, ignore_index, multidim_average, average): + preds, target = input + if 
ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelHammingDistance, + sk_metric=partial( + _sk_hamming_distance_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multilabel_hamming_distance_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 4: + pytest.skip("samplewise and non-multidim arrays are not valid") + + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_hamming_distance, + sk_metric=partial( + _sk_hamming_distance_multilabel, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "num_labels": NUM_CLASSES, + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + }, + ) + + def test_multilabel_hamming_distance_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_hamming_distance_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, ) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_hamming_distance_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelHammingDistance, + metric_functional=multilabel_hamming_distance, + metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD}, + dtype=dtype, + ) + + +# -------------------------- Old stuff -------------------------- + +# def _sk_hamming_loss(preds, target): +# sk_preds, sk_target, _ = _input_format_classification(preds, target, threshold=THRESHOLD) +# sk_preds, sk_target = sk_preds.numpy(), sk_target.numpy() +# sk_preds, sk_target = sk_preds.reshape(sk_preds.shape[0], -1), 
sk_target.reshape(sk_target.shape[0], -1) + +# return sk_hamming_loss(y_true=sk_target, y_pred=sk_preds) + + +# @pytest.mark.parametrize( +# "preds, target", +# [ +# (_input_binary_logits.preds, _input_binary_logits.target), +# (_input_binary_prob.preds, _input_binary_prob.target), +# (_input_binary.preds, _input_binary.target), +# (_input_mlb_logits.preds, _input_mlb_logits.target), +# (_input_mlb_prob.preds, _input_mlb_prob.target), +# (_input_mlb.preds, _input_mlb.target), +# (_input_mcls_logits.preds, _input_mcls_logits.target), +# (_input_mcls_prob.preds, _input_mcls_prob.target), +# (_input_mcls.preds, _input_mcls.target), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target), +# (_input_mdmc.preds, _input_mdmc.target), +# (_input_mlmd_prob.preds, _input_mlmd_prob.target), +# (_input_mlmd.preds, _input_mlmd.target), +# ], +# ) +# class TestHammingDistance(MetricTester): +# @pytest.mark.parametrize("ddp", [True, False]) +# @pytest.mark.parametrize("dist_sync_on_step", [False, True]) +# def test_hamming_distance_class(self, ddp, dist_sync_on_step, preds, target): +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=HammingDistance, +# sk_metric=_sk_hamming_loss, +# dist_sync_on_step=dist_sync_on_step, +# metric_args={"threshold": THRESHOLD}, +# ) + +# def test_hamming_distance_fn(self, preds, target): +# self.run_functional_metric_test( +# preds=preds, +# target=target, +# metric_functional=hamming_distance, +# sk_metric=_sk_hamming_loss, +# metric_args={"threshold": THRESHOLD}, +# ) + +# def test_hamming_distance_differentiability(self, preds, target): +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=HammingDistance, +# metric_functional=hamming_distance, +# metric_args={"threshold": THRESHOLD}, +# ) + -@pytest.mark.parametrize("threshold", [1.5]) -def test_wrong_params(threshold): - preds, target = _input_mcls_prob.preds, _input_mcls_prob.target +# @pytest.mark.parametrize("threshold", [1.5]) +# def test_wrong_params(threshold): +# preds, target = _input_mcls_prob.preds, _input_mcls_prob.target - with pytest.raises(ValueError): - ham_dist = HammingDistance(threshold=threshold) - ham_dist(preds, target) - ham_dist.compute() +# with pytest.raises(ValueError): +# ham_dist = HammingDistance(threshold=threshold) +# ham_dist(preds, target) +# ham_dist.compute() - with pytest.raises(ValueError): - hamming_distance(preds, target, threshold=threshold) +# with pytest.raises(ValueError): +# hamming_distance(preds, target, threshold=threshold) diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 0a9fef9f4c2..f17daf2f4c1 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -344,6 +344,66 @@ def test_top_k( assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), result) +def _sk_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average): + if average == "micro": + preds = preds.flatten() + target = target.flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + return sk_fn(target, preds) + + precision_recall, weights = [], [] + for i in range(preds.shape[1]): + pred, true = preds[:, i].flatten(), target[:, i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + precision_recall.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + 
weights.append(confmat[1, 1] + confmat[1, 0]) + res = np.stack(precision_recall, axis=0) + + if average == "macro": + return res.mean(0) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + +def _sk_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average): + precision_recall, weights = [], [] + for i in range(preds.shape[0]): + if average == "micro": + pred, true = preds[i].flatten(), target[i].flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + precision_recall.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + weights.append(confmat[1, 1] + confmat[1, 0]) + else: + scores, w = [], [] + for j in range(preds.shape[1]): + pred, true = preds[i, j], target[i, j] + true, pred = remove_ignore_index(true, pred, ignore_index) + scores.append(sk_fn(true, pred)) + confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) + w.append(confmat[1, 1] + confmat[1, 0]) + precision_recall.append(np.stack(scores)) + weights.append(np.stack(w)) + if average == "micro": + return np.array(precision_recall) + res = np.stack(precision_recall, 0) + if average == "macro": + return res.mean(-1) + elif average == "weighted": + weights = np.stack(weights, 0).astype(float) + weights_norm = weights.sum(-1, keepdims=True) + weights_norm[weights_norm == 0] = 1.0 + return ((weights * res) / weights_norm).sum(-1) + elif average is None or average == "none": + return res + + def _sk_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim_average, average): preds = preds.numpy() target = target.numpy() @@ -360,61 +420,8 @@ def _sk_precision_recall_multilabel(preds, target, sk_fn, ignore_index, multidim average=average, ) elif multidim_average == "global": - if average == "micro": - preds = preds.flatten() - target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) - return sk_fn(target, preds) - - precision_recall, weights = [], [] - for i in range(preds.shape[1]): - pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - precision_recall.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - res = np.stack(precision_recall, axis=0) - - if average == "macro": - return res.mean(0) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res - else: - precision_recall, weights = [], [] - for i in range(preds.shape[0]): - if average == "micro": - pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) - precision_recall.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - weights.append(confmat[1, 1] + confmat[1, 0]) - else: - scores, w = [], [] - for j in range(preds.shape[1]): - pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) - scores.append(sk_fn(true, pred)) - confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) - w.append(confmat[1, 1] + confmat[1, 0]) - 
precision_recall.append(np.stack(scores)) - weights.append(np.stack(w)) - if average == "micro": - return np.array(precision_recall) - res = np.stack(precision_recall, 0) - if average == "macro": - return res.mean(-1) - elif average == "weighted": - weights = np.stack(weights, 0).astype(float) - weights_norm = weights.sum(-1, keepdims=True) - weights_norm[weights_norm == 0] = 1.0 - return ((weights * res) / weights_norm).sum(-1) - elif average is None or average == "none": - return res + return _sk_precision_recall_multilabel_global(preds, target, sk_fn, ignore_index, average) + return _sk_precision_recall_multilabel_local(preds, target, sk_fn, ignore_index, average) @pytest.mark.parametrize("input", _multilabel_cases) diff --git a/tests/unittests/classification/test_specificity.py b/tests/unittests/classification/test_specificity.py index 3771b397072..ca50d88ed49 100644 --- a/tests/unittests/classification/test_specificity.py +++ b/tests/unittests/classification/test_specificity.py @@ -11,402 +11,896 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from functools import partial -from typing import Callable, Optional import numpy as np import pytest import torch -from sklearn.metrics import multilabel_confusion_matrix +from scipy.special import expit as sigmoid +from sklearn.metrics import confusion_matrix as sk_confusion_matrix from torch import Tensor, tensor -from torchmetrics import Metric, Specificity -from torchmetrics.functional import specificity -from torchmetrics.functional.classification.stat_scores import _reduce_stat_scores -from torchmetrics.utilities.checks import _input_format_classification -from torchmetrics.utilities.enums import AverageMethod -from unittests.classification.inputs import _input_binary, _input_binary_prob -from unittests.classification.inputs import _input_multiclass as _input_mcls -from unittests.classification.inputs import _input_multiclass_prob as _input_mcls_prob -from unittests.classification.inputs import _input_multidim_multiclass as _input_mdmc -from unittests.classification.inputs import _input_multidim_multiclass_prob as _input_mdmc_prob -from unittests.classification.inputs import _input_multilabel as _input_mlb -from unittests.classification.inputs import _input_multilabel_prob as _input_mlb_prob +from torchmetrics.classification.specificity import BinarySpecificity, MulticlassSpecificity, MultilabelSpecificity +from torchmetrics.functional.classification.specificity import ( + binary_specificity, + multiclass_specificity, + multilabel_specificity, +) +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6 +from unittests.classification.inputs import _binary_cases, _multiclass_cases, _multilabel_cases from unittests.helpers import seed_all -from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from unittests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester, inject_ignore_index seed_all(42) -def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - sk_preds, sk_target = preds.numpy(), target.numpy() - - if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: - sk_preds = np.delete(sk_preds, ignore_index, 1) - sk_target = np.delete(sk_target, ignore_index, 
1) - - if preds.shape[1] == 1 and reduce == "samples": - sk_target = sk_target.T - sk_preds = sk_preds.T - - sk_stats = multilabel_confusion_matrix( - sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 - ) - - if preds.shape[1] == 1 and reduce != "samples": - sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +def _calc_specificity(tn, fp): + """safely calculate specificity.""" + denom = tn + fp + if np.isscalar(tn): + denom = 1.0 if denom == 0 else denom else: - sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] - - if reduce == "micro": - sk_stats = sk_stats.sum(axis=0, keepdims=True) - - sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) - - if reduce == "micro": - sk_stats = sk_stats[0] + denom[denom == 0] = 1.0 + return tn / denom - if reduce == "macro" and ignore_index is not None and preds.shape[1]: - sk_stats[ignore_index, :] = -1 - if reduce == "micro": - _, fp, tn, _, _ = sk_stats +def _sk_specificity_binary(preds, target, ignore_index, multidim_average): + if multidim_average == "global": + preds = preds.view(-1).numpy() + target = target.view(-1).numpy() else: - _, fp, tn, _ = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] - return fp, tn - - -def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): - - if stats: - fp, tn = stats + preds = preds.numpy() + target = target.numpy() + + if np.issubdtype(preds.dtype, np.floating): + if not ((0 < preds) & (preds < 1)).all(): + preds = sigmoid(preds) + preds = (preds >= THRESHOLD).astype(np.uint8) + + if multidim_average == "global": + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + tn, fp, _, _ = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() + return _calc_specificity(tn, fp) else: - stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k) - fp, tn = stats - - fp, tn = tensor(fp), tensor(tn) - spec = _reduce_stat_scores( - numerator=tn, - denominator=tn + fp, - weights=None if reduce != "weighted" else tn + fp, - average=reduce, - mdmc_average=mdmc_reduce, - ) - if reduce in [None, "none"] and ignore_index is not None and preds.shape[1] > 1: - spec = spec.numpy() - spec = np.insert(spec, ignore_index, math.nan) - spec = tensor(spec) - - return spec - - -def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k=None): - preds, target, _ = _input_format_classification( - preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k - ) - - if mdmc_reduce == "global": - preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) - target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) - return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) - fp, tn = [], [] - stats = [] - - for i in range(preds.shape[0]): - pred_i = preds[i, ...].T - target_i = target[i, ...].T - fp_i, tn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k) - fp.append(fp_i) - tn.append(tn_i) + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + tn, fp, _, _ = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() + res.append(_calc_specificity(tn, fp)) + return np.stack(res) + + 
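# Editorial usage sketch (illustrative only, not lines from the patch): the reference helpers above
# compute specificity as TN / (TN + FP) from sklearn's confusion matrix, and the new functional API
# introduced by this change is expected to agree with that definition. The import path below is the
# one used in the added docstring examples (torchmetrics.functional.binary_specificity), and the
# data is the docstring's own example, so the expected value is tensor(0.6667).
import torch
from sklearn.metrics import confusion_matrix
from torchmetrics.functional import binary_specificity

target = torch.tensor([0, 1, 0, 1, 0, 1])
preds = torch.tensor([0, 0, 1, 1, 0, 1])

# sklearn reference: unravel the 2x2 confusion matrix into tn, fp, fn, tp and apply TN / (TN + FP)
tn, fp, fn, tp = confusion_matrix(target.numpy(), preds.numpy(), labels=[0, 1]).ravel()
sk_spec = tn / (tn + fp)  # 2 / (2 + 1) = 0.666...

# the functional interface added by this patch should match the hand-computed value
tm_spec = binary_specificity(preds, target)  # tensor(0.6667)
assert torch.isclose(tm_spec, torch.tensor(sk_spec, dtype=tm_spec.dtype))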
+@pytest.mark.parametrize("input", _binary_cases) +class TestBinarySpecificity(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("ddp", [False, True]) + def test_binary_specificity(self, ddp, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") - stats.append(fp) - stats.append(tn) - return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinarySpecificity, + sk_metric=partial(_sk_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={"threshold": THRESHOLD, "ignore_index": ignore_index, "multidim_average": multidim_average}, + ) + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + def test_binary_specificity_functional(self, input, ignore_index, multidim_average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and preds.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") -@pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)]) -@pytest.mark.parametrize( - "average, mdmc_average, num_classes, ignore_index, match_str", - [ - ("wrong", None, None, None, "`average`"), - ("micro", "wrong", None, None, "`mdmc"), - ("macro", None, None, None, "number of classes"), - ("macro", None, 1, 0, "ignore_index"), - ], -) -def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): - with pytest.raises(ValueError, match=match_str): - metric( - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_specificity, + sk_metric=partial(_sk_specificity_binary, ignore_index=ignore_index, multidim_average=multidim_average), + metric_args={ + "threshold": THRESHOLD, + "ignore_index": ignore_index, + "multidim_average": multidim_average, + }, ) - with pytest.raises(ValueError, match=match_str): - fn_metric( - _input_binary.preds[0], - _input_binary.target[0], - average=average, - mdmc_average=mdmc_average, - num_classes=num_classes, - ignore_index=ignore_index, + def test_binary_specificity_differentiability(self, input): + preds, target = input + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, ) + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_specificity_dtype_cpu(self, input, dtype): + preds, target = input + if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6: + pytest.xfail(reason="half support of core ops not support before pytorch v1.6") + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( + preds=preds, + 
target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -def test_zero_division(metric_class, metric_fn): - """Test that zero_division works correctly (currently should just set to 0).""" - - preds = tensor([1, 2, 1, 1]) - target = tensor([0, 0, 0, 0]) - - cl_metric = metric_class(average="none", num_classes=3) - cl_metric(preds, target) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_specificity_dtype_gpu(self, input, dtype): + preds, target = input + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinarySpecificity, + metric_functional=binary_specificity, + metric_args={"threshold": THRESHOLD}, + dtype=dtype, + ) - result_cl = cl_metric.compute() - result_fn = metric_fn(preds, target, average="none", num_classes=3) - assert result_cl[0] == result_fn[0] == 0 +def _sk_specificity_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + + if ignore_index is not None: + idx = target == ignore_index + target = target[~idx] + preds = preds[~idx] + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + if average == "micro": + return _calc_specificity(tn.sum(), fp.sum()) + + res = _calc_specificity(tn, fp) + if average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + +def _sk_specificity_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + + if ignore_index is not None: + idx = true == ignore_index + true = true[~idx] + pred = pred[~idx] + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + if average == "micro": + res.append(_calc_specificity(tn.sum(), fp.sum())) + + r = _calc_specificity(tn, fp) + if average == "macro": + res.append(r.mean(0)) + elif average == "weighted": + w = tp + fn + res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0)) + elif average is None or average == "none": + res.append(r) + return np.stack(res, 0) + + +def _sk_specificity_multiclass(preds, target, ignore_index, multidim_average, average): + if preds.ndim == target.ndim + 1: + preds = torch.argmax(preds, 1) + if multidim_average == "global": + return _sk_specificity_multiclass_global(preds, target, ignore_index, average) + return _sk_specificity_multiclass_local(preds, target, ignore_index, average) + + +@pytest.mark.parametrize("input", _multiclass_cases) +class TestMulticlassSpecificity(MetricTester): + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + @pytest.mark.parametrize("ddp", [True, False]) + def test_multiclass_specificity(self, ddp, input, ignore_index, multidim_average, 
average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") + if multidim_average == "samplewise" and ddp: + pytest.skip("samplewise and ddp give different order than non ddp") + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassSpecificity, + sk_metric=partial( + _sk_specificity_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -def test_no_support(metric_class, metric_fn): - """This tests a rare edge case, where there is only one class present. + @pytest.mark.parametrize("ignore_index", [None, 0, -1]) + @pytest.mark.parametrize("multidim_average", ["global", "samplewise"]) + @pytest.mark.parametrize("average", ["micro", "macro", None]) + def test_multiclass_specificity_functional(self, input, ignore_index, multidim_average, average): + preds, target = input + if ignore_index == -1: + target = inject_ignore_index(target, ignore_index) + if multidim_average == "samplewise" and target.ndim < 3: + pytest.skip("samplewise and non-multidim arrays are not valid") - in target, and ignore_index is set to exactly that class - and the - average method is equal to 'weighted'. + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_specificity, + sk_metric=partial( + _sk_specificity_multiclass, + ignore_index=ignore_index, + multidim_average=multidim_average, + average=average, + ), + metric_args={ + "ignore_index": ignore_index, + "multidim_average": multidim_average, + "average": average, + "num_classes": NUM_CLASSES, + }, + ) - This would mean that the sum of weights equals zero, and would, without - taking care of this case, return NaN. However, the reduction function - should catch that and set the metric to equal the value of zero_division - in this case (zero_division is for now not configurable and equals 0). 
-    """
+    def test_multiclass_specificity_differentiability(self, input):
+        preds, target = input
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSpecificity,
+            metric_functional=multiclass_specificity,
+            metric_args={"num_classes": NUM_CLASSES},
+        )
-    preds = tensor([1, 1, 0, 0])
-    target = tensor([0, 0, 0, 0])
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multiclass_specificity_dtype_cpu(self, input, dtype):
+        preds, target = input
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+            pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+        if (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSpecificity,
+            metric_functional=multiclass_specificity,
+            metric_args={"num_classes": NUM_CLASSES},
+            dtype=dtype,
+        )
-    cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=1)
-    cl_metric(preds, target)
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multiclass_specificity_dtype_gpu(self, input, dtype):
+        preds, target = input
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MulticlassSpecificity,
+            metric_functional=multiclass_specificity,
+            metric_args={"num_classes": NUM_CLASSES},
+            dtype=dtype,
+        )
-    result_cl = cl_metric.compute()
-    result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=1)
-    assert result_cl == result_fn == 0
+_mc_k_target = tensor([0, 1, 2])
+_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
-@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
-@pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"])
-@pytest.mark.parametrize("ignore_index", [None, 0])
 @pytest.mark.parametrize(
-    "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper",
+    "k, preds, target, average, expected_spec",
     [
-        (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_spec),
-        (_input_binary.preds, _input_binary.target, 1, False, None, _sk_spec),
-        (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_spec),
-        (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_spec),
-        (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec),
-        (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec),
-        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls),
-        (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls),
-        (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls),
-        (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls),
+        (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
+        (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
     ],
 )
-class TestSpecificity(MetricTester):
-    @pytest.mark.parametrize("ddp", [False, True])
-    @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_specificity_class(
-        self,
-        ddp: bool,
-        dist_sync_on_step: bool,
-        preds: Tensor,
-        target: Tensor,
-        sk_wrapper: Callable,
-        metric_class: Metric,
-        metric_fn: Callable,
-        multiclass: Optional[bool],
-        num_classes: Optional[int],
-        average: str,
-        mdmc_average: Optional[str],
-        ignore_index: Optional[int],
-    ):
-        # todo: `metric_fn` is unused
-        if num_classes == 1 and average != "micro":
-            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-        if ignore_index is not None and preds.ndim == 2:
-            pytest.skip("Skipping ignore_index test with binary inputs.")
-
-        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
+def test_top_k(k: int, preds: Tensor, target: Tensor, average: str, expected_spec: Tensor):
+    """A simple test to check that top_k works as expected."""
+    class_metric = MulticlassSpecificity(top_k=k, average=average, num_classes=3)
+    class_metric.update(preds, target)
+
+    assert torch.equal(class_metric.compute(), expected_spec)
+    assert torch.equal(multiclass_specificity(preds, target, top_k=k, average=average, num_classes=3), expected_spec)
+
+
+def _sk_specificity_multilabel_global(preds, target, ignore_index, average):
+    tns, fps = [], []
+    for i in range(preds.shape[1]):
+        p, t = preds[:, i].flatten(), target[:, i].flatten()
+        if ignore_index is not None:
+            idx = t == ignore_index
+            t = t[~idx]
+            p = p[~idx]
+        tn, fp, fn, tp = sk_confusion_matrix(t, p, labels=[0, 1]).ravel()
+        tns.append(tn)
+        fps.append(fp)
+
+    tn = np.array(tns)
+    fp = np.array(fps)
+    if average == "micro":
+        return _calc_specificity(tn.sum(), fp.sum())
+
+    res = _calc_specificity(tn, fp)
+    if average == "macro":
+        return res.mean(0)
+    elif average == "weighted":
+        w = res[:, 0] + res[:, 3]
+        return (res * (w / w.sum()).reshape(-1, 1)).sum(0)
+    elif average is None or average == "none":
+        return res
+
+
+def _sk_specificity_multilabel_local(preds, target, ignore_index, average):
+    specificity = []
+    for i in range(preds.shape[0]):
+        tns, fps = [], []
+        for j in range(preds.shape[1]):
+            pred, true = preds[i, j], target[i, j]
+            if ignore_index is not None:
+                idx = true == ignore_index
+                true = true[~idx]
+                pred = pred[~idx]
+            tn, fp, _, _ = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel()
+            tns.append(tn)
+            fps.append(fp)
+        tn = np.array(tns)
+        fp = np.array(fps)
+        if average == "micro":
+            specificity.append(_calc_specificity(tn.sum(), fp.sum()))
+        else:
+            specificity.append(_calc_specificity(tn, fp))
+
+    res = np.stack(specificity, 0)
+    if average == "micro" or average is None or average == "none":
+        return res
+    elif average == "macro":
+        return res.mean(-1)
+    elif average == "weighted":
+        w = res[:, 0, :] + res[:, 3, :]
+        return (res * (w / w.sum())[:, np.newaxis]).sum(-1)
+    elif average is None or average == "none":
+        return np.moveaxis(res, 1, -1)
+
+
+def _sk_specificity_multilabel(preds, target, ignore_index, multidim_average, average):
+    preds = preds.numpy()
+    target = target.numpy()
+    if np.issubdtype(preds.dtype, np.floating):
+        if not ((0 < preds) & (preds < 1)).all():
+            preds = sigmoid(preds)
+        preds = (preds >= THRESHOLD).astype(np.uint8)
+    preds = preds.reshape(*preds.shape[:2], -1)
+    target = target.reshape(*target.shape[:2], -1)
+    if multidim_average == "global":
+        return _sk_specificity_multilabel_global(preds, target, ignore_index, average)
+    return _sk_specificity_multilabel_local(preds, target, ignore_index, average)
+
+
+@pytest.mark.parametrize("input", _multilabel_cases)
+class TestMultilabelSpecificity(MetricTester):
+    @pytest.mark.parametrize("ddp", [True, False])
+    @pytest.mark.parametrize("ignore_index", [None, 0, -1])
+    @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
+    @pytest.mark.parametrize("average", ["micro", "macro", None])
+    def test_multilabel_specificity(self, ddp, input, ignore_index, multidim_average, average):
+        preds, target = input
+        if ignore_index == -1:
+            target = inject_ignore_index(target, ignore_index)
+        if multidim_average == "samplewise" and preds.ndim < 4:
+            pytest.skip("samplewise and non-multidim arrays are not valid")
+        if multidim_average == "samplewise" and ddp:
+            pytest.skip("samplewise and ddp give different order than non ddp")
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
             target=target,
-            metric_class=metric_class,
+            metric_class=MultilabelSpecificity,
             sk_metric=partial(
-                sk_wrapper,
-                reduce=average,
-                num_classes=num_classes,
-                multiclass=multiclass,
+                _sk_specificity_multilabel,
                 ignore_index=ignore_index,
-                mdmc_reduce=mdmc_average,
+                multidim_average=multidim_average,
+                average=average,
             ),
-            dist_sync_on_step=dist_sync_on_step,
             metric_args={
-                "num_classes": num_classes,
-                "average": average,
+                "num_labels": NUM_CLASSES,
                 "threshold": THRESHOLD,
-                "multiclass": multiclass,
                 "ignore_index": ignore_index,
-                "mdmc_average": mdmc_average,
+                "multidim_average": multidim_average,
+                "average": average,
             },
         )
-    def test_specificity_fn(
-        self,
-        preds: Tensor,
-        target: Tensor,
-        sk_wrapper: Callable,
-        metric_class: Metric,
-        metric_fn: Callable,
-        multiclass: Optional[bool],
-        num_classes: Optional[int],
-        average: str,
-        mdmc_average: Optional[str],
-        ignore_index: Optional[int],
-    ):
-        # todo: `metric_class` is unused
-        if num_classes == 1 and average != "micro":
-            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-        if ignore_index is not None and preds.ndim == 2:
-            pytest.skip("Skipping ignore_index test with binary inputs.")
-
-        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
+    @pytest.mark.parametrize("ignore_index", [None, 0, -1])
+    @pytest.mark.parametrize("multidim_average", ["global", "samplewise"])
+    @pytest.mark.parametrize("average", ["micro", "macro", None])
+    def test_multilabel_specificity_functional(self, input, ignore_index, multidim_average, average):
+        preds, target = input
+        if ignore_index == -1:
+            target = inject_ignore_index(target, ignore_index)
+        if multidim_average == "samplewise" and preds.ndim < 4:
+            pytest.skip("samplewise and non-multidim arrays are not valid")
         self.run_functional_metric_test(
-            preds,
-            target,
-            metric_functional=metric_fn,
+            preds=preds,
+            target=target,
+            metric_functional=multilabel_specificity,
             sk_metric=partial(
-                sk_wrapper,
-                reduce=average,
-                num_classes=num_classes,
-                multiclass=multiclass,
+                _sk_specificity_multilabel,
                 ignore_index=ignore_index,
-                mdmc_reduce=mdmc_average,
+                multidim_average=multidim_average,
+                average=average,
            ),
             metric_args={
-                "num_classes": num_classes,
-                "average": average,
+                "num_labels": NUM_CLASSES,
                 "threshold": THRESHOLD,
-                "multiclass": multiclass,
                 "ignore_index": ignore_index,
-                "mdmc_average": mdmc_average,
+                "multidim_average": multidim_average,
+                "average": average,
             },
         )
-    def test_accuracy_differentiability(
-        self,
-        preds: Tensor,
-        target: Tensor,
-        sk_wrapper: Callable,
-        metric_class: Metric,
-        metric_fn: Callable,
-        multiclass: Optional[bool],
-        num_classes: Optional[int],
-        average: str,
-        mdmc_average: Optional[str],
-        ignore_index: Optional[int],
-    ):
-
-        if num_classes == 1 and average != "micro":
-            pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)")
-
-        if ignore_index is not None and preds.ndim == 2:
-            pytest.skip("Skipping ignore_index test with binary inputs.")
-
-        if average == "weighted" and ignore_index is not None and mdmc_average is not None:
-            pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average")
-
+    def test_multilabel_specificity_differentiability(self, input):
+        preds, target = input
         self.run_differentiability_test(
             preds=preds,
             target=target,
-            metric_module=metric_class,
-            metric_functional=metric_fn,
-            metric_args={
-                "num_classes": num_classes,
-                "average": average,
-                "threshold": THRESHOLD,
-                "multiclass": multiclass,
-                "ignore_index": ignore_index,
-                "mdmc_average": mdmc_average,
-            },
+            metric_module=MultilabelSpecificity,
+            metric_functional=multilabel_specificity,
+            metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
         )
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_specificity_dtype_cpu(self, input, dtype):
+        preds, target = input
+        if dtype == torch.half and not _TORCH_GREATER_EQUAL_1_6:
+            pytest.xfail(reason="half support of core ops not supported before pytorch v1.6")
+        if (preds < 0).any() and dtype == torch.half:
+            pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision")
+        self.run_precision_test_cpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelSpecificity,
+            metric_functional=multilabel_specificity,
+            metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
+            dtype=dtype,
+        )
-_mc_k_target = tensor([0, 1, 2])
-_mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]])
-_ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]])
-_ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]])
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    @pytest.mark.parametrize("dtype", [torch.half, torch.double])
+    def test_multilabel_specificity_dtype_gpu(self, input, dtype):
+        preds, target = input
+        self.run_precision_test_gpu(
+            preds=preds,
+            target=target,
+            metric_module=MultilabelSpecificity,
+            metric_functional=multilabel_specificity,
+            metric_args={"num_labels": NUM_CLASSES, "threshold": THRESHOLD},
+            dtype=dtype,
+        )
-@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)])
-@pytest.mark.parametrize(
-    "k, preds, target, average, expected_spec",
-    [
-        (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)),
-        (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)),
-        (1, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 2)),
-        (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6)),
-    ],
-)
-def test_top_k(
-    metric_class,
-    metric_fn,
-    k: int,
-    preds: Tensor,
-    target: Tensor,
-    average: str,
-    expected_spec: Tensor,
-):
-    """A simple test to check that top_k works as expected.
-
-    Just a sanity check, the tests in StatScores should already guarantee the correctness of results.
- """ - - class_metric = metric_class(top_k=k, average=average, num_classes=3) - class_metric.update(preds, target) +# -------------------------- Old stuff -------------------------- + + +# def _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k): +# preds, target, _ = _input_format_classification( +# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) +# sk_preds, sk_target = preds.numpy(), target.numpy() + +# if reduce != "macro" and ignore_index is not None and preds.shape[1] > 1: +# sk_preds = np.delete(sk_preds, ignore_index, 1) +# sk_target = np.delete(sk_target, ignore_index, 1) + +# if preds.shape[1] == 1 and reduce == "samples": +# sk_target = sk_target.T +# sk_preds = sk_preds.T + +# sk_stats = multilabel_confusion_matrix( +# sk_target, sk_preds, samplewise=(reduce == "samples") and preds.shape[1] != 1 +# ) + +# if preds.shape[1] == 1 and reduce != "samples": +# sk_stats = sk_stats[[1]].reshape(-1, 4)[:, [3, 1, 0, 2]] +# else: +# sk_stats = sk_stats.reshape(-1, 4)[:, [3, 1, 0, 2]] + +# if reduce == "micro": +# sk_stats = sk_stats.sum(axis=0, keepdims=True) + +# sk_stats = np.concatenate([sk_stats, sk_stats[:, [3]] + sk_stats[:, [0]]], 1) + +# if reduce == "micro": +# sk_stats = sk_stats[0] + +# if reduce == "macro" and ignore_index is not None and preds.shape[1]: +# sk_stats[ignore_index, :] = -1 + +# if reduce == "micro": +# _, fp, tn, _, _ = sk_stats +# else: +# _, fp, tn, _ = sk_stats[:, 0], sk_stats[:, 1], sk_stats[:, 2], sk_stats[:, 3] +# return fp, tn + + +# def _sk_spec(preds, target, reduce, num_classes, multiclass, ignore_index, top_k=None, mdmc_reduce=None, stats=None): + +# if stats: +# fp, tn = stats +# else: +# stats = _sk_stats_score(preds, target, reduce, num_classes, multiclass, ignore_index, top_k) +# fp, tn = stats + +# fp, tn = tensor(fp), tensor(tn) +# spec = _reduce_specificity( +# numerator=tn, +# denominator=tn + fp, +# weights=None if reduce != "weighted" else tn + fp, +# average=reduce, +# mdmc_average=mdmc_reduce, +# ) +# if reduce in [None, "none"] and ignore_index is not None and preds.shape[1] > 1: +# spec = spec.numpy() +# spec = np.insert(spec, ignore_index, math.nan) +# spec = tensor(spec) - assert torch.equal(class_metric.compute(), expected_spec) - assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), expected_spec) +# return spec -@pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) -@pytest.mark.parametrize( - "ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] -) -def test_class_not_present(metric_class, metric_fn, ignore_index, expected): - """This tests that when metric is computed per class and a given class is not present in both the `preds` and - `target`, the resulting score is `nan`.""" - preds = torch.tensor([0, 0, 0]) - target = torch.tensor([0, 0, 0]) - num_classes = 2 - - # test functional - result_fn = metric_fn(preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - assert torch.allclose(expected, result_fn, equal_nan=True) - - # test class - cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) - cl_metric(preds, target) - result_cl = cl_metric.compute() - assert torch.allclose(expected, result_cl, equal_nan=True) +# def _sk_spec_mdim_mcls(preds, target, reduce, mdmc_reduce, num_classes, multiclass, ignore_index, top_k=None): +# preds, target, _ = 
_input_format_classification( +# preds, target, threshold=THRESHOLD, num_classes=num_classes, multiclass=multiclass, top_k=top_k +# ) + +# if mdmc_reduce == "global": +# preds = torch.transpose(preds, 1, 2).reshape(-1, preds.shape[1]) +# target = torch.transpose(target, 1, 2).reshape(-1, target.shape[1]) +# return _sk_spec(preds, target, reduce, num_classes, False, ignore_index, top_k, mdmc_reduce) +# fp, tn = [], [] +# stats = [] + +# for i in range(preds.shape[0]): +# pred_i = preds[i, ...].T +# target_i = target[i, ...].T +# fp_i, tn_i = _sk_stats_score(pred_i, target_i, reduce, num_classes, False, ignore_index, top_k) +# fp.append(fp_i) +# tn.append(tn_i) + +# stats.append(fp) +# stats.append(tn) +# return _sk_spec(preds[0], target[0], reduce, num_classes, multiclass, ignore_index, top_k, mdmc_reduce, stats) + + +# @pytest.mark.parametrize("metric, fn_metric", [(Specificity, specificity)]) +# @pytest.mark.parametrize( +# "average, mdmc_average, num_classes, ignore_index, match_str", +# [ +# ("wrong", None, None, None, "`average`"), +# ("micro", "wrong", None, None, "`mdmc"), +# ("macro", None, None, None, "number of classes"), +# ("macro", None, 1, 0, "ignore_index"), +# ], +# ) +# def test_wrong_params(metric, fn_metric, average, mdmc_average, num_classes, ignore_index, match_str): +# with pytest.raises(ValueError, match=match_str): +# metric( +# average=average, +# mdmc_average=mdmc_average, +# num_classes=num_classes, +# ignore_index=ignore_index, +# ) + +# with pytest.raises(ValueError, match=match_str): +# fn_metric( +# _input_binary.preds[0], +# _input_binary.target[0], +# average=average, +# mdmc_average=mdmc_average, +# num_classes=num_classes, +# ignore_index=ignore_index, +# ) + + +# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +# def test_zero_division(metric_class, metric_fn): +# """Test that zero_division works correctly (currently should just set to 0).""" + +# preds = tensor([1, 2, 1, 1]) +# target = tensor([0, 0, 0, 0]) + +# cl_metric = metric_class(average="none", num_classes=3) +# cl_metric(preds, target) + +# result_cl = cl_metric.compute() +# result_fn = metric_fn(preds, target, average="none", num_classes=3) + +# assert result_cl[0] == result_fn[0] == 0 + + +# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +# def test_no_support(metric_class, metric_fn): +# """This tests a rare edge case, where there is only one class present. + +# in target, and ignore_index is set to exactly that class - and the +# average method is equal to 'weighted'. + +# This would mean that the sum of weights equals zero, and would, without +# taking care of this case, return NaN. However, the reduction function +# should catch that and set the metric to equal the value of zero_division +# in this case (zero_division is for now not configurable and equals 0). 
+# """ + +# preds = tensor([1, 1, 0, 0]) +# target = tensor([0, 0, 0, 0]) + +# cl_metric = metric_class(average="weighted", num_classes=2, ignore_index=1) +# cl_metric(preds, target) + +# result_cl = cl_metric.compute() +# result_fn = metric_fn(preds, target, average="weighted", num_classes=2, ignore_index=1) + +# assert result_cl == result_fn == 0 + + +# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +# @pytest.mark.parametrize("average", ["micro", "macro", None, "weighted", "samples"]) +# @pytest.mark.parametrize("ignore_index", [None, 0]) +# @pytest.mark.parametrize( +# "preds, target, num_classes, multiclass, mdmc_average, sk_wrapper", +# [ +# (_input_binary_prob.preds, _input_binary_prob.target, 1, None, None, _sk_spec), +# (_input_binary.preds, _input_binary.target, 1, False, None, _sk_spec), +# (_input_mlb_prob.preds, _input_mlb_prob.target, NUM_CLASSES, None, None, _sk_spec), +# (_input_mlb.preds, _input_mlb.target, NUM_CLASSES, False, None, _sk_spec), +# (_input_mcls_prob.preds, _input_mcls_prob.target, NUM_CLASSES, None, None, _sk_spec), +# (_input_mcls.preds, _input_mcls.target, NUM_CLASSES, None, None, _sk_spec), +# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "global", _sk_spec_mdim_mcls), +# (_input_mdmc.preds, _input_mdmc.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), +# (_input_mdmc_prob.preds, _input_mdmc_prob.target, NUM_CLASSES, None, "samplewise", _sk_spec_mdim_mcls), +# ], +# ) +# class TestSpecificity(MetricTester): +# @pytest.mark.parametrize("ddp", [False, True]) +# @pytest.mark.parametrize("dist_sync_on_step", [True, False]) +# def test_specificity_class( +# self, +# ddp: bool, +# dist_sync_on_step: bool, +# preds: Tensor, +# target: Tensor, +# sk_wrapper: Callable, +# metric_class: Metric, +# metric_fn: Callable, +# multiclass: Optional[bool], +# num_classes: Optional[int], +# average: str, +# mdmc_average: Optional[str], +# ignore_index: Optional[int], +# ): +# # todo: `metric_fn` is unused +# if num_classes == 1 and average != "micro": +# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# if average == "weighted" and ignore_index is not None and mdmc_average is not None: +# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + +# self.run_class_metric_test( +# ddp=ddp, +# preds=preds, +# target=target, +# metric_class=metric_class, +# sk_metric=partial( +# sk_wrapper, +# reduce=average, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# mdmc_reduce=mdmc_average, +# ), +# dist_sync_on_step=dist_sync_on_step, +# metric_args={ +# "num_classes": num_classes, +# "average": average, +# "threshold": THRESHOLD, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "mdmc_average": mdmc_average, +# }, +# ) + +# def test_specificity_fn( +# self, +# preds: Tensor, +# target: Tensor, +# sk_wrapper: Callable, +# metric_class: Metric, +# metric_fn: Callable, +# multiclass: Optional[bool], +# num_classes: Optional[int], +# average: str, +# mdmc_average: Optional[str], +# ignore_index: Optional[int], +# ): +# # todo: `metric_class` is unused +# if num_classes == 1 and average != "micro": +# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in 
sklearn)") + +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# if average == "weighted" and ignore_index is not None and mdmc_average is not None: +# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + +# self.run_functional_metric_test( +# preds, +# target, +# metric_functional=metric_fn, +# sk_metric=partial( +# sk_wrapper, +# reduce=average, +# num_classes=num_classes, +# multiclass=multiclass, +# ignore_index=ignore_index, +# mdmc_reduce=mdmc_average, +# ), +# metric_args={ +# "num_classes": num_classes, +# "average": average, +# "threshold": THRESHOLD, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "mdmc_average": mdmc_average, +# }, +# ) + +# def test_accuracy_differentiability( +# self, +# preds: Tensor, +# target: Tensor, +# sk_wrapper: Callable, +# metric_class: Metric, +# metric_fn: Callable, +# multiclass: Optional[bool], +# num_classes: Optional[int], +# average: str, +# mdmc_average: Optional[str], +# ignore_index: Optional[int], +# ): + +# if num_classes == 1 and average != "micro": +# pytest.skip("Only test binary data for 'micro' avg (equivalent of 'binary' in sklearn)") + +# if ignore_index is not None and preds.ndim == 2: +# pytest.skip("Skipping ignore_index test with binary inputs.") + +# if average == "weighted" and ignore_index is not None and mdmc_average is not None: +# pytest.skip("Ignore special case where we are ignoring entire sample for 'weighted' average") + +# self.run_differentiability_test( +# preds=preds, +# target=target, +# metric_module=metric_class, +# metric_functional=metric_fn, +# metric_args={ +# "num_classes": num_classes, +# "average": average, +# "threshold": THRESHOLD, +# "multiclass": multiclass, +# "ignore_index": ignore_index, +# "mdmc_average": mdmc_average, +# }, +# ) + + +# _mc_k_target = tensor([0, 1, 2]) +# _mc_k_preds = tensor([[0.35, 0.4, 0.25], [0.1, 0.5, 0.4], [0.2, 0.1, 0.7]]) +# _ml_k_target = tensor([[0, 1, 0], [1, 1, 0], [0, 0, 0]]) +# _ml_k_preds = tensor([[0.9, 0.2, 0.75], [0.1, 0.7, 0.8], [0.6, 0.1, 0.7]]) + + +# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +# @pytest.mark.parametrize( +# "k, preds, target, average, expected_spec", +# [ +# (1, _mc_k_preds, _mc_k_target, "micro", tensor(5 / 6)), +# (2, _mc_k_preds, _mc_k_target, "micro", tensor(1 / 2)), +# (1, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 2)), +# (2, _ml_k_preds, _ml_k_target, "micro", tensor(1 / 6)), +# ], +# ) +# def test_top_k( +# metric_class, +# metric_fn, +# k: int, +# preds: Tensor, +# target: Tensor, +# average: str, +# expected_spec: Tensor, +# ): +# """A simple test to check that top_k works as expected. + +# Just a sanity check, the tests in Specificity should already guarantee the correctness of results. 
+# """ + +# class_metric = metric_class(top_k=k, average=average, num_classes=3) +# class_metric.update(preds, target) + +# assert torch.equal(class_metric.compute(), expected_spec) +# assert torch.equal(metric_fn(preds, target, top_k=k, average=average, num_classes=3), expected_spec) + + +# @pytest.mark.parametrize("metric_class, metric_fn", [(Specificity, specificity)]) +# @pytest.mark.parametrize( +# "ignore_index, expected", [(None, torch.tensor([0.0, np.nan])), (0, torch.tensor([np.nan, np.nan]))] +# ) +# def test_class_not_present(metric_class, metric_fn, ignore_index, expected): +# """This tests that when metric is computed per class and a given class is not present in both the `preds` and +# `target`, the resulting score is `nan`.""" +# preds = torch.tensor([0, 0, 0]) +# target = torch.tensor([0, 0, 0]) +# num_classes = 2 + +# # test functional +# result_fn = metric_fn( +# preds, target, average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index +# ) +# assert torch.allclose(expected, result_fn, equal_nan=True) + +# # test class +# cl_metric = metric_class(average=AverageMethod.NONE, num_classes=num_classes, ignore_index=ignore_index) +# cl_metric(preds, target) +# result_cl = cl_metric.compute() +# assert torch.allclose(expected, result_cl, equal_nan=True) diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 7f8007be7e4..4b1cf2046e6 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -145,55 +145,61 @@ def test_binary_stat_scores_dtype_gpu(self, input, dtype): ) -def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average): - if preds.ndim == target.ndim + 1: - preds = torch.argmax(preds, 1) - if multidim_average == "global": - preds = preds.numpy().flatten() - target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) - confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) +def _sk_stat_scores_multiclass_global(preds, target, ignore_index, average): + preds = preds.numpy().flatten() + target = target.numpy().flatten() + target, preds = remove_ignore_index(target, preds, ignore_index) + confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) + tp = np.diag(confmat) + fp = confmat.sum(0) - tp + fn = confmat.sum(1) - tp + tn = confmat.sum() - (fp + fn + tp) + + res = np.stack([tp, fp, tn, fn, tp + fn], 1) + if average == "micro": + return res.sum(0) + elif average == "macro": + return res.mean(0) + elif average == "weighted": + w = tp + fn + return (res * (w / w.sum()).reshape(-1, 1)).sum(0) + elif average is None or average == "none": + return res + + +def _sk_stat_scores_multiclass_local(preds, target, ignore_index, average): + preds = preds.numpy() + target = target.numpy() + + res = [] + for pred, true in zip(preds, target): + pred = pred.flatten() + true = true.flatten() + true, pred = remove_ignore_index(true, pred, ignore_index) + confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) tp = np.diag(confmat) fp = confmat.sum(0) - tp fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) - - res = np.stack([tp, fp, tn, fn, tp + fn], 1) + r = np.stack([tp, fp, tn, fn, tp + fn], 1) if average == "micro": - return res.sum(0) + res.append(r.sum(0)) elif average == "macro": - return res.mean(0) + res.append(r.mean(0)) elif average == "weighted": 
             w = tp + fn
-            return (res * (w / w.sum()).reshape(-1, 1)).sum(0)
+            res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0))
         elif average is None or average == "none":
-            return res
+            res.append(r)
+    return np.stack(res, 0)
-
-    else:
-        preds = preds.numpy()
-        target = target.numpy()
-
-        res = []
-        for pred, true in zip(preds, target):
-            pred = pred.flatten()
-            true = true.flatten()
-            true, pred = remove_ignore_index(true, pred, ignore_index)
-            confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES)))
-            tp = np.diag(confmat)
-            fp = confmat.sum(0) - tp
-            fn = confmat.sum(1) - tp
-            tn = confmat.sum() - (fp + fn + tp)
-            r = np.stack([tp, fp, tn, fn, tp + fn], 1)
-            if average == "micro":
-                res.append(r.sum(0))
-            elif average == "macro":
-                res.append(r.mean(0))
-            elif average == "weighted":
-                w = tp + fn
-                res.append((r * (w / w.sum()).reshape(-1, 1)).sum(0))
-            elif average is None or average == "none":
-                res.append(r)
-        return np.stack(res, 0)
+def _sk_stat_scores_multiclass(preds, target, ignore_index, multidim_average, average):
+    if preds.ndim == target.ndim + 1:
+        preds = torch.argmax(preds, 1)
+    if multidim_average == "global":
+        return _sk_stat_scores_multiclass_global(preds, target, ignore_index, average)
+    return _sk_stat_scores_multiclass_local(preds, target, ignore_index, average)


 @pytest.mark.parametrize("input", _multiclass_cases)