diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a5b6d85c6c..6c942cf3883 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Deprecated +- Renamed IoU -> Jaccard Index ([#662](https://github.com/PyTorchLightning/metrics/pull/662)) + ### Removed diff --git a/docs/source/references/functional.rst b/docs/source/references/functional.rst index 2696a250511..38a752d8734 100644 --- a/docs/source/references/functional.rst +++ b/docs/source/references/functional.rst @@ -144,10 +144,10 @@ hinge [func] .. autofunction:: torchmetrics.functional.hinge :noindex: -iou [func] -~~~~~~~~~~ +jaccard_index [func] +~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: torchmetrics.functional.iou +.. autofunction:: torchmetrics.functional.jaccard_index :noindex: kl_divergence [func] diff --git a/docs/source/references/modules.rst b/docs/source/references/modules.rst index 44e7aedc6a3..a9ea013894c 100644 --- a/docs/source/references/modules.rst +++ b/docs/source/references/modules.rst @@ -302,10 +302,10 @@ Hinge .. autoclass:: torchmetrics.Hinge :noindex: -IoU -~~~ +JaccardIndex +~~~~~~~~~~~~ -.. autoclass:: torchmetrics.IoU +.. autoclass:: torchmetrics.JaccardIndex :noindex: KLDivergence diff --git a/tests/classification/test_iou.py b/tests/classification/test_jaccard.py similarity index 77% rename from tests/classification/test_iou.py rename to tests/classification/test_jaccard.py index 2a83ae034a4..3b6260d0b92 100644 --- a/tests/classification/test_iou.py +++ b/tests/classification/test_jaccard.py @@ -27,60 +27,60 @@ from tests.classification.inputs import _input_multilabel as _input_mlb from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester -from torchmetrics.classification.iou import IoU -from torchmetrics.functional import iou +from torchmetrics.classification.jaccard import JaccardIndex +from torchmetrics.functional import jaccard_index -def _sk_iou_binary_prob(preds, target, average=None): +def _sk_jaccard_binary_prob(preds, target, average=None): sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_binary(preds, target, average=None): +def _sk_jaccard_binary(preds, target, average=None): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multilabel_prob(preds, target, average=None): +def _sk_jaccard_multilabel_prob(preds, target, average=None): sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multilabel(preds, target, average=None): +def _sk_jaccard_multilabel(preds, target, average=None): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multiclass_prob(preds, target, average=None): +def _sk_jaccard_multiclass_prob(preds, target, average=None): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multiclass(preds, target, average=None): +def _sk_jaccard_multiclass(preds, target, average=None): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multidim_multiclass_prob(preds, target, average=None): +def _sk_jaccard_multidim_multiclass_prob(preds, target, average=None): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() sk_target = target.view(-1).numpy() return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average) -def _sk_iou_multidim_multiclass(preds, target, average=None): +def _sk_jaccard_multidim_multiclass(preds, target, average=None): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -91,47 +91,47 @@ def _sk_iou_multidim_multiclass(preds, target, average=None): @pytest.mark.parametrize( "preds, target, sk_metric, num_classes", [ - (_input_binary_prob.preds, _input_binary_prob.target, _sk_iou_binary_prob, 2), - (_input_binary.preds, _input_binary.target, _sk_iou_binary, 2), - (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_iou_multilabel_prob, 2), - (_input_mlb.preds, _input_mlb.target, _sk_iou_multilabel, 2), - (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_iou_multiclass_prob, NUM_CLASSES), - (_input_mcls.preds, _input_mcls.target, _sk_iou_multiclass, NUM_CLASSES), - (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_iou_multidim_multiclass_prob, NUM_CLASSES), - (_input_mdmc.preds, _input_mdmc.target, _sk_iou_multidim_multiclass, NUM_CLASSES), + (_input_binary_prob.preds, _input_binary_prob.target, _sk_jaccard_binary_prob, 2), + (_input_binary.preds, _input_binary.target, _sk_jaccard_binary, 2), + (_input_mlb_prob.preds, _input_mlb_prob.target, _sk_jaccard_multilabel_prob, 2), + (_input_mlb.preds, _input_mlb.target, _sk_jaccard_multilabel, 2), + (_input_mcls_prob.preds, _input_mcls_prob.target, _sk_jaccard_multiclass_prob, NUM_CLASSES), + (_input_mcls.preds, _input_mcls.target, _sk_jaccard_multiclass, NUM_CLASSES), + (_input_mdmc_prob.preds, _input_mdmc_prob.target, _sk_jaccard_multidim_multiclass_prob, NUM_CLASSES), + (_input_mdmc.preds, _input_mdmc.target, _sk_jaccard_multidim_multiclass, NUM_CLASSES), ], ) -class TestIoU(MetricTester): +class TestJaccardIndex(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_iou(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): + def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step): average = "macro" if reduction == "elementwise_mean" else None # convert tags self.run_class_metric_test( ddp=ddp, preds=preds, target=target, - metric_class=IoU, + metric_class=JaccardIndex, sk_metric=partial(sk_metric, average=average), dist_sync_on_step=dist_sync_on_step, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) - def test_iou_functional(self, reduction, preds, target, sk_metric, num_classes): + def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes): average = "macro" if reduction == "elementwise_mean" else None # convert tags self.run_functional_metric_test( preds, target, - metric_functional=iou, + metric_functional=jaccard_index, sk_metric=partial(sk_metric, average=average), metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) - def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_classes): + def test_jaccard_differentiability(self, reduction, preds, target, sk_metric, num_classes): self.run_differentiability_test( preds=preds, target=target, - metric_module=IoU, - metric_functional=iou, + metric_module=JaccardIndex, + metric_functional=jaccard_index, metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction}, ) @@ -147,18 +147,18 @@ def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_cl (True, "none", 0, Tensor([2 / 3, 1 / 2])), ], ) -def test_iou(half_ones, reduction, ignore_index, expected): +def test_jaccard(half_ones, reduction, ignore_index, expected): preds = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) if half_ones: preds[:60] = 1 - iou_val = iou( + jaccard_val = jaccard_index( preds=preds, target=target, ignore_index=ignore_index, reduction=reduction, ) - assert torch.allclose(iou_val, expected, atol=1e-9) + assert torch.allclose(jaccard_val, expected, atol=1e-9) # test `absent_score` @@ -194,8 +194,8 @@ def test_iou(half_ones, reduction, ignore_index, expected): ([0, 2], [0, 2], 0, 1.0, 3, [1.0, 1.0]), ], ) -def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): - iou_val = iou( +def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): + jaccard_val = jaccard_index( preds=tensor(pred), target=tensor(target), ignore_index=ignore_index, @@ -203,7 +203,7 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, num_classes=num_classes, reduction="none", ) - assert torch.allclose(iou_val, tensor(expected).to(iou_val)) + assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) # example data taken from @@ -224,12 +224,12 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]), ], ) -def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): - iou_val = iou( +def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): + jaccard_val = jaccard_index( preds=tensor(pred), target=tensor(target), ignore_index=ignore_index, num_classes=num_classes, reduction=reduction, ) - assert torch.allclose(iou_val, tensor(expected).to(iou_val)) + assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) diff --git a/torchmetrics/__init__.py b/torchmetrics/__init__.py index b4d188239cd..f6018e39037 100644 --- a/torchmetrics/__init__.py +++ b/torchmetrics/__init__.py @@ -14,7 +14,7 @@ from torchmetrics import functional # noqa: E402 from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402 from torchmetrics.audio import PIT, SDR, SI_SDR, SI_SNR, SNR # noqa: E402 -from torchmetrics.classification import ( # noqa: E402 +from torchmetrics.classification import ( # noqa: E402, F401 AUC, AUROC, F1, @@ -31,6 +31,7 @@ HammingDistance, Hinge, IoU, + JaccardIndex, KLDivergence, MatthewsCorrcoef, Precision, @@ -101,7 +102,7 @@ "FBeta", "HammingDistance", "Hinge", - "IoU", + "JaccardIndex", "KLDivergence", "MatthewsCorrcoef", "MaxMetric", diff --git a/torchmetrics/classification/__init__.py b/torchmetrics/classification/__init__.py index 0ed2d3d8d8b..f76fd097571 100644 --- a/torchmetrics/classification/__init__.py +++ b/torchmetrics/classification/__init__.py @@ -25,6 +25,7 @@ from torchmetrics.classification.hamming_distance import HammingDistance # noqa: F401 from torchmetrics.classification.hinge import Hinge # noqa: F401 from torchmetrics.classification.iou import IoU # noqa: F401 +from torchmetrics.classification.jaccard import JaccardIndex # noqa: F401 from torchmetrics.classification.kl_divergence import KLDivergence # noqa: F401 from torchmetrics.classification.matthews_corrcoef import MatthewsCorrcoef # noqa: F401 from torchmetrics.classification.precision_recall import Precision, Recall # noqa: F401 diff --git a/torchmetrics/classification/iou.py b/torchmetrics/classification/iou.py index 789d57e8523..f89e96ad0c5 100644 --- a/torchmetrics/classification/iou.py +++ b/torchmetrics/classification/iou.py @@ -12,60 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Optional +from warnings import warn import torch -from torch import Tensor -from torchmetrics.classification.confusion_matrix import ConfusionMatrix -from torchmetrics.functional.classification.iou import _iou_from_confmat +from torchmetrics.classification.jaccard import JaccardIndex -class IoU(ConfusionMatrix): +class IoU(JaccardIndex): r""" Computes Intersection over union, or `Jaccard index`_: - .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - - Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values. - They may be subject to conversion from input data (see description below). Note that it is different from box IoU. - - Works with binary, multiclass and multi-label data. - Accepts probabilities from a model output or integer class values in prediction. - Works with multi-dimensional preds and target. - - Forward accepts - - - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes - - ``target`` (long tensor): ``(N, ...)`` - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. - - Args: - num_classes: Number of classes in the dataset. - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1]. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of the class index were present in - `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. - threshold: - Threshold value for binary or multi-label probabilities. - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - compute_on_step: - Forward only calls ``update()`` and return None if this is set to False. - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) + .. deprecated:: v0.7 + Use :class:`torchmetrics.JaccardIndex`. Will be removed in v0.8. Example: >>> from torchmetrics import IoU @@ -77,8 +36,6 @@ class IoU(ConfusionMatrix): tensor(0.9660) """ - is_differentiable = False - higher_is_better = True def __init__( self, @@ -91,18 +48,14 @@ def __init__( dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ) -> None: + warn("`IoU` was renamed to `JaccardIndex` in v0.7 and it will be removed in v0.8", DeprecationWarning) super().__init__( num_classes=num_classes, - normalize=None, + ignore_index=ignore_index, + absent_score=absent_score, threshold=threshold, + reduction=reduction, compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) - self.reduction = reduction - self.ignore_index = ignore_index - self.absent_score = absent_score - - def compute(self) -> Tensor: - """Computes intersection over union (IoU)""" - return _iou_from_confmat(self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction) diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py new file mode 100644 index 00000000000..3535f98d4bc --- /dev/null +++ b/torchmetrics/classification/jaccard.py @@ -0,0 +1,110 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional + +import torch +from torch import Tensor + +from torchmetrics.classification.confusion_matrix import ConfusionMatrix +from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat + + +class JaccardIndex(ConfusionMatrix): + r""" + Computes Intersection over union, or `Jaccard index`_: + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Where: :math:`A` and :math:`B` are both tensors of the same size, containing integer class values. + They may be subject to conversion from input data (see description below). Note that it is different from box IoU. + + Works with binary, multiclass and multi-label data. + Accepts probabilities from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1]. By default, no index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of the class index were present in + `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. + threshold: + Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + >>> from torchmetrics import JaccardIndex + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> jaccard = JaccardIndex(num_classes=2) + >>> jaccard(pred, target) + tensor(0.9660) + + """ + is_differentiable = False + higher_is_better = True + + def __init__( + self, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + reduction: str = "elementwise_mean", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ) -> None: + super().__init__( + num_classes=num_classes, + normalize=None, + threshold=threshold, + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + self.reduction = reduction + self.ignore_index = ignore_index + self.absent_score = absent_score + + def compute(self) -> Tensor: + """Computes intersection over union (IoU)""" + return _jaccard_from_confmat( + self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction + ) diff --git a/torchmetrics/functional/__init__.py b/torchmetrics/functional/__init__.py index 380185c3a06..167cfa00247 100644 --- a/torchmetrics/functional/__init__.py +++ b/torchmetrics/functional/__init__.py @@ -27,7 +27,8 @@ from torchmetrics.functional.classification.f_beta import f1, fbeta from torchmetrics.functional.classification.hamming_distance import hamming_distance from torchmetrics.functional.classification.hinge import hinge -from torchmetrics.functional.classification.iou import iou +from torchmetrics.functional.classification.iou import iou # noqa: F401 +from torchmetrics.functional.classification.jaccard import jaccard_index from torchmetrics.functional.classification.kl_divergence import kl_divergence from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall @@ -95,7 +96,7 @@ "hamming_distance", "hinge", "image_gradients", - "iou", + "jaccard_index", "kl_divergence", "matthews_corrcoef", "mean_absolute_error", diff --git a/torchmetrics/functional/classification/__init__.py b/torchmetrics/functional/classification/__init__.py index 3731a3dd26c..f65507f5862 100644 --- a/torchmetrics/functional/classification/__init__.py +++ b/torchmetrics/functional/classification/__init__.py @@ -22,7 +22,7 @@ from torchmetrics.functional.classification.f_beta import f1, fbeta # noqa: F401 from torchmetrics.functional.classification.hamming_distance import hamming_distance # noqa: F401 from torchmetrics.functional.classification.hinge import hinge # noqa: F401 -from torchmetrics.functional.classification.iou import iou # noqa: F401 +from torchmetrics.functional.classification.jaccard import jaccard_index # noqa: F401 from torchmetrics.functional.classification.kl_divergence import kl_divergence # noqa: F401 from torchmetrics.functional.classification.matthews_corrcoef import matthews_corrcoef # noqa: F401 from torchmetrics.functional.classification.precision_recall import precision, precision_recall, recall # noqa: F401 diff --git a/torchmetrics/functional/classification/iou.py b/torchmetrics/functional/classification/iou.py index f89650320b5..b78952ad00b 100644 --- a/torchmetrics/functional/classification/iou.py +++ b/torchmetrics/functional/classification/iou.py @@ -12,58 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional +from warnings import warn import torch from torch import Tensor -from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update -from torchmetrics.utilities.data import get_num_classes -from torchmetrics.utilities.distributed import reduce - - -def _iou_from_confmat( - confmat: Tensor, - num_classes: int, - ignore_index: Optional[int] = None, - absent_score: float = 0.0, - reduction: str = "elementwise_mean", -) -> Tensor: - """Computes the intersection over union from confusion matrix. - - Args: - confmat: Confusion matrix without normalization - num_classes: Number of classes for a given prediction and target tensor - ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute - to the returned score, regardless of reduction method. - absent_score: score to use for an individual class, if no instances of the class index were present in `pred` - AND no instances of the class index were present in `target`. - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - """ - - # Remove the ignored class index from the scores. - if ignore_index is not None and 0 <= ignore_index < num_classes: - confmat[ignore_index] = 0.0 - - intersection = torch.diag(confmat) - union = confmat.sum(0) + confmat.sum(1) - intersection - - # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. - scores = intersection.float() / union.float() - scores[union == 0] = absent_score - - if ignore_index is not None and 0 <= ignore_index < num_classes: - scores = torch.cat( - [ - scores[:ignore_index], - scores[ignore_index + 1 :], - ] - ) - - return reduce(scores, reduction=reduction) +from torchmetrics.functional.classification.jaccard import jaccard_index def iou( @@ -76,58 +30,22 @@ def iou( reduction: str = "elementwise_mean", ) -> Tensor: r""" - Computes `Jaccard index`_ - - .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} - - Where: :math:`A` and :math:`B` are both tensors of the same size, - containing integer class values. They may be subject to conversion from - input data (see description below). - - Note that it is different from box IoU. - - If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument - to convert into integer labels. This is the case for binary and multi-label probabilities. - - If pred has an extra dimension as in the case of multi-class scores we - perform an argmax on ``dim=1``. - - Args: - preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]`` - target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]`` - ignore_index: optional int specifying a target class to ignore. If given, - this class index does not contribute to the returned score, regardless - of reduction method. Has no effect if given an int that is not in the - range [0, num_classes-1], where num_classes is either given or derived - from pred and target. By default, no index is ignored, and all classes are used. - absent_score: score to use for an individual class, if no instances of - the class index were present in `pred` AND no instances of the class - index were present in `target`. For example, if we have 3 classes, - [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be - assigned the `absent_score`. - threshold: - Threshold value for binary or multi-label probabilities. default: 0.5 - num_classes: - Optionally specify the number of classes - reduction: a method to reduce metric score over labels. - - - ``'elementwise_mean'``: takes the mean (default) - - ``'sum'``: takes the sum - - ``'none'``: no reduction will be applied - - Return: - IoU score: Tensor containing single value if reduction is - 'elementwise_mean', or number of classes if reduction is 'none' - - Example: - >>> from torchmetrics.functional import iou - >>> target = torch.randint(0, 2, (10, 25, 25)) - >>> pred = torch.tensor(target) - >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] - >>> iou(pred, target) - tensor(0.9660) + Computes `Jaccard index`_ + + .. deprecated:: v0.7 + Use :func:`torchmetrics.functional.jaccard_index`. Will be removed in v0.8. + + Example: + >>> from torchmetrics.functional import iou + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> iou(pred, target) + tensor(0.9660) """ - - num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes) - confmat = _confusion_matrix_update(preds, target, num_classes, threshold) - return _iou_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction) + warn( + "Function `iou` is renamed in v0.7 and will be removed in v0.8." + " Use `functional.functional.jaccard_index` instead.", + DeprecationWarning, + ) + return jaccard_index(preds, target, ignore_index, absent_score, threshold, num_classes, reduction) diff --git a/torchmetrics/functional/classification/jaccard.py b/torchmetrics/functional/classification/jaccard.py new file mode 100644 index 00000000000..c5cc3ab6999 --- /dev/null +++ b/torchmetrics/functional/classification/jaccard.py @@ -0,0 +1,133 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +import torch +from torch import Tensor + +from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update +from torchmetrics.utilities.data import get_num_classes +from torchmetrics.utilities.distributed import reduce + + +def _jaccard_from_confmat( + confmat: Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + reduction: str = "elementwise_mean", +) -> Tensor: + """Computes the intersection over union from confusion matrix. + + Args: + confmat: Confusion matrix without normalization + num_classes: Number of classes for a given prediction and target tensor + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. + absent_score: score to use for an individual class, if no instances of the class index were present in `pred` + AND no instances of the class index were present in `target`. + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + """ + + # Remove the ignored class index from the scores. + if ignore_index is not None and 0 <= ignore_index < num_classes: + confmat[ignore_index] = 0.0 + + intersection = torch.diag(confmat) + union = confmat.sum(0) + confmat.sum(1) - intersection + + # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class. + scores = intersection.float() / union.float() + scores[union == 0] = absent_score + + if ignore_index is not None and 0 <= ignore_index < num_classes: + scores = torch.cat( + [ + scores[:ignore_index], + scores[ignore_index + 1 :], + ] + ) + + return reduce(scores, reduction=reduction) + + +def jaccard_index( + preds: Tensor, + target: Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, + threshold: float = 0.5, + num_classes: Optional[int] = None, + reduction: str = "elementwise_mean", +) -> Tensor: + r""" + Computes `Jaccard index`_ + + .. math:: J(A,B) = \frac{|A\cap B|}{|A\cup B|} + + Where: :math:`A` and :math:`B` are both tensors of the same size, + containing integer class values. They may be subject to conversion from + input data (see description below). + + Note that it is different from box IoU. + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument + to convert into integer labels. This is the case for binary and multi-label probabilities. + + If pred has an extra dimension as in the case of multi-class scores we + perform an argmax on ``dim=1``. + + Args: + preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]`` + target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]`` + ignore_index: optional int specifying a target class to ignore. If given, + this class index does not contribute to the returned score, regardless + of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1], where num_classes is either given or derived + from pred and target. By default, no index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of + the class index were present in `pred` AND no instances of the class + index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be + assigned the `absent_score`. + threshold: + Threshold value for binary or multi-label probabilities. default: 0.5 + num_classes: + Optionally specify the number of classes + reduction: a method to reduce metric score over labels. + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + Return: + IoU score: Tensor containing single value if reduction is + 'elementwise_mean', or number of classes if reduction is 'none' + + Example: + >>> from torchmetrics.functional import jaccard_index + >>> target = torch.randint(0, 2, (10, 25, 25)) + >>> pred = torch.tensor(target) + >>> pred[2:5, 7:13, 9:15] = 1 - pred[2:5, 7:13, 9:15] + >>> jaccard_index(pred, target) + tensor(0.9660) + """ + + num_classes = get_num_classes(preds=preds, target=target, num_classes=num_classes) + confmat = _confusion_matrix_update(preds, target, num_classes, threshold) + return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)