diff --git a/CHANGELOG.md b/CHANGELOG.md
index e96df58d964..1314874c59c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,7 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Changed
 
--
+- Renamed `reduction` argument to `average` in Jaccard score and added additional options ([#874](https://github.com/PyTorchLightning/metrics/pull/874))
 
 ### Deprecated
 
diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py
index e9b8deb438d..80dc44c0b2a 100644
--- a/tests/classification/test_jaccard.py
+++ b/tests/classification/test_jaccard.py
@@ -87,7 +87,7 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
     return sk_jaccard_score(y_true=sk_target, y_pred=sk_preds, average=average)
 
 
-@pytest.mark.parametrize("reduction", ["elementwise_mean", "none"])
+@pytest.mark.parametrize("average", [None, "macro", "micro", "weighted"])
 @pytest.mark.parametrize(
     "preds, target, sk_metric, num_classes",
     [
@@ -104,8 +104,8 @@ def _sk_jaccard_multidim_multiclass(preds, target, average=None):
 class TestJaccardIndex(MetricTester):
     @pytest.mark.parametrize("ddp", [True, False])
     @pytest.mark.parametrize("dist_sync_on_step", [True, False])
-    def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
-        average = "macro" if reduction == "elementwise_mean" else None  # convert tags
+    def test_jaccard(self, average, preds, target, sk_metric, num_classes, ddp, dist_sync_on_step):
+        # average = "macro" if reduction == "elementwise_mean" else None  # convert tags
         self.run_class_metric_test(
             ddp=ddp,
             preds=preds,
@@ -113,41 +113,41 @@ def test_jaccard(self, reduction, preds, target, sk_metric, num_classes, ddp, di
             metric_class=JaccardIndex,
             sk_metric=partial(sk_metric, average=average),
             dist_sync_on_step=dist_sync_on_step,
-            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
+            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
         )
 
-    def test_jaccard_functional(self, reduction, preds, target, sk_metric, num_classes):
-        average = "macro" if reduction == "elementwise_mean" else None  # convert tags
+    def test_jaccard_functional(self, average, preds, target, sk_metric, num_classes):
+        # average = "macro" if reduction == "elementwise_mean" else None  # convert tags
         self.run_functional_metric_test(
             preds,
             target,
             metric_functional=jaccard_index,
             sk_metric=partial(sk_metric, average=average),
-            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
+            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
         )
 
-    def test_jaccard_differentiability(self, reduction, preds, target, sk_metric, num_classes):
+    def test_jaccard_differentiability(self, average, preds, target, sk_metric, num_classes):
         self.run_differentiability_test(
             preds=preds,
             target=target,
             metric_module=JaccardIndex,
             metric_functional=jaccard_index,
-            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "reduction": reduction},
+            metric_args={"num_classes": num_classes, "threshold": THRESHOLD, "average": average},
         )
 
 
 @pytest.mark.parametrize(
-    ["half_ones", "reduction", "ignore_index", "expected"],
+    ["half_ones", "average", "ignore_index", "expected"],
     [
         (False, "none", None, Tensor([1, 1, 1])),
-        (False, "elementwise_mean", None, Tensor([1])),
+        (False, "macro", None, Tensor([1])),
         (False, "none", 0, Tensor([1, 1])),
         (True, "none", None, Tensor([0.5, 0.5, 0.5])),
-        (True, "elementwise_mean", None, Tensor([0.5])),
+        (True, "macro", None, Tensor([0.5])),
         (True, "none", 0, Tensor([2 / 3, 1 / 2])),
     ],
 )
-def test_jaccard(half_ones, reduction, ignore_index, expected):
+def test_jaccard(half_ones, average, ignore_index, expected):
     preds = (torch.arange(120) % 3).view(-1, 1)
     target = (torch.arange(120) % 3).view(-1, 1)
     if half_ones:
@@ -155,9 +155,10 @@ def test_jaccard(half_ones, reduction, ignore_index, expected):
     jaccard_val = jaccard_index(
         preds=preds,
         target=target,
+        average=average,
         num_classes=3,
         ignore_index=ignore_index,
-        reduction=reduction,
+        # reduction=reduction,
     )
     assert torch.allclose(jaccard_val, expected, atol=1e-9)
 
@@ -199,10 +200,11 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
     jaccard_val = jaccard_index(
         preds=tensor(pred),
         target=tensor(target),
+        average=None,
         ignore_index=ignore_index,
         absent_score=absent_score,
         num_classes=num_classes,
-        reduction="none",
+        # reduction="none",
     )
     assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
 
@@ -210,7 +212,7 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
 # example data taken from
 # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
 @pytest.mark.parametrize(
-    ["pred", "target", "ignore_index", "num_classes", "reduction", "expected"],
+    ["pred", "target", "ignore_index", "num_classes", "average", "expected"],
     [
         # Ignoring an index outside of [0, num_classes-1] should have no effect.
         ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, "none", [1, 1 / 2, 2 / 3]),
@@ -221,16 +223,17 @@ def test_jaccard_absent_score(pred, target, ignore_index, absent_score, num_clas
         ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, "none", [1, 2 / 3]),
         ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, "none", [1, 1]),
         # When reducing to mean or sum, the ignored index does not contribute to the output.
-        ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "elementwise_mean", [7 / 12]),
-        ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
+        ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "macro", [7 / 12]),
+        # ([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, "sum", [7 / 6]),
     ],
 )
-def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
+def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, expected):
     jaccard_val = jaccard_index(
         preds=tensor(pred),
         target=tensor(target),
+        average=average,
         ignore_index=ignore_index,
         num_classes=num_classes,
-        reduction=reduction,
+        # reduction=reduction,
     )
     assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py
index 4c7377e1f5b..58f2438f671 100644
--- a/torchmetrics/classification/jaccard.py
+++ b/torchmetrics/classification/jaccard.py
@@ -15,7 +15,6 @@
 import torch
 from torch import Tensor
-from typing_extensions import Literal
 
 from torchmetrics.classification.confusion_matrix import ConfusionMatrix
 from torchmetrics.functional.classification.jaccard import _jaccard_from_confmat
@@ -45,6 +44,18 @@ class JaccardIndex(ConfusionMatrix):
 
     Args:
         num_classes: Number of classes in the dataset.
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'micro'``: Calculate the metric globally, across all samples and classes.
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class. Note that if a given class doesn't occur in the
+              `preds` or `target`, the value for the class will be ``nan``.
+
         ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
             to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
             range [0, num_classes-1]. By default, no index is ignored, and all classes are used.
@@ -53,12 +64,6 @@ class JaccardIndex(ConfusionMatrix):
             [0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be assigned the `absent_score`.
         threshold: Threshold value for binary or multi-label probabilities.
         multilabel: determines if data is multilabel or not.
-        reduction: a method to reduce metric score over labels:
-
-            - ``'elementwise_mean'``: takes the mean (default)
-            - ``'sum'``: takes the sum
-            - ``'none'``: no reduction will be applied
-
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Example:
@@ -78,11 +83,11 @@ def __init__(
         self,
         num_classes: int,
+        average: Optional[str] = "macro",
         ignore_index: Optional[int] = None,
         absent_score: float = 0.0,
         threshold: float = 0.5,
         multilabel: bool = False,
-        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
         **kwargs: Dict[str, Any],
     ) -> None:
         super().__init__(
@@ -92,12 +97,16 @@ def __init__(
             multilabel=multilabel,
             **kwargs,
         )
-        self.reduction = reduction
+        self.average = average
         self.ignore_index = ignore_index
         self.absent_score = absent_score
 
     def compute(self) -> Tensor:
         """Computes intersection over union (IoU)"""
         return _jaccard_from_confmat(
-            self.confmat, self.num_classes, self.ignore_index, self.absent_score, self.reduction
+            self.confmat,
+            self.num_classes,
+            self.average,
+            self.ignore_index,
+            self.absent_score,
         )
diff --git a/torchmetrics/functional/classification/jaccard.py b/torchmetrics/functional/classification/jaccard.py
index a73a9031a94..4f9b9c2400e 100644
--- a/torchmetrics/functional/classification/jaccard.py
+++ b/torchmetrics/functional/classification/jaccard.py
@@ -15,65 +15,90 @@
 import torch
 from torch import Tensor
-from typing_extensions import Literal
 
 from torchmetrics.functional.classification.confusion_matrix import _confusion_matrix_update
-from torchmetrics.utilities.distributed import reduce
 
 
 def _jaccard_from_confmat(
     confmat: Tensor,
     num_classes: int,
+    average: Optional[str] = "macro",
     ignore_index: Optional[int] = None,
     absent_score: float = 0.0,
-    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
 ) -> Tensor:
     """Computes the intersection over union from confusion matrix.
 
     Args:
         confmat: Confusion matrix without normalization
         num_classes: Number of classes for a given prediction and target tensor
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'micro'``: Calculate the metric globally, across all samples and classes.
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class. Note that if a given class doesn't occur in the
+              `preds` or `target`, the value for the class will be ``nan``.
+
         ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
             to the returned score, regardless of reduction method.
-        absent_score: score to use for an individual class, if no instances of the class index were present in ``preds``
-            AND no instances of the class index were present in ``target``.
-        reduction: a method to reduce metric score over labels.
-
-            - ``'elementwise_mean'``: takes the mean (default)
-            - ``'sum'``: takes the sum
-            - ``'none'`` or ``None``: no reduction will be applied
+        absent_score: score to use for an individual class, if no instances of the class index were present in `pred`
+            AND no instances of the class index were present in `target`.
     """
+    allowed_average = ["micro", "macro", "weighted", "none", None]
+    if average not in allowed_average:
+        raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
     # Remove the ignored class index from the scores.
     if ignore_index is not None and 0 <= ignore_index < num_classes:
         confmat[ignore_index] = 0.0
 
-    intersection = torch.diag(confmat)
-    union = confmat.sum(0) + confmat.sum(1) - intersection
-
-    # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
-    scores = intersection.float() / union.float()
-    scores[union == 0] = absent_score
-
-    if ignore_index is not None and 0 <= ignore_index < num_classes:
-        scores = torch.cat(
-            [
-                scores[:ignore_index],
-                scores[ignore_index + 1 :],
-            ]
+    if average == "none" or average is None:
+        intersection = torch.diag(confmat)
+        union = confmat.sum(0) + confmat.sum(1) - intersection
+
+        # If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
+        scores = intersection.float() / union.float()
+        scores[union == 0] = absent_score
+
+        if ignore_index is not None and 0 <= ignore_index < num_classes:
+            scores = torch.cat(
+                [
+                    scores[:ignore_index],
+                    scores[ignore_index + 1 :],
+                ]
+            )
+        return scores
+
+    if average == "macro":
+        scores = _jaccard_from_confmat(
+            confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
         )
+        return torch.mean(scores)
 
-    return reduce(scores, reduction=reduction)
+    if average == "micro":
+        intersection = torch.sum(torch.diag(confmat))
+        union = torch.sum(torch.sum(confmat, dim=1) + torch.sum(confmat, dim=0) - torch.diag(confmat))
+        return intersection.float() / union.float()
+
+    weights = torch.sum(confmat, dim=1).float() / torch.sum(confmat).float()
+    scores = _jaccard_from_confmat(
+        confmat, num_classes, average="none", ignore_index=ignore_index, absent_score=absent_score
+    )
+    return torch.sum(weights * scores)
 
 
 def jaccard_index(
     preds: Tensor,
     target: Tensor,
     num_classes: int,
+    average: Optional[str] = "macro",
     ignore_index: Optional[int] = None,
     absent_score: float = 0.0,
     threshold: float = 0.5,
-    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
 ) -> Tensor:
     r"""Computes `Jaccard index`_
@@ -95,6 +120,18 @@
         preds: tensor containing predictions from model (probabilities, or labels) with shape ``[N, d1, d2, ...]``
         target: tensor containing ground truth labels with shape ``[N, d1, d2, ...]``
         num_classes: Specify the number of classes
+        average:
+            Defines the reduction that is applied. Should be one of the following:
+
+            - ``'macro'`` [default]: Calculate the metric for each class separately, and average the
+              metrics across classes (with equal weights for each class).
+            - ``'micro'``: Calculate the metric globally, across all samples and classes.
+            - ``'weighted'``: Calculate the metric for each class separately, and average the
+              metrics across classes, weighting each class by its support (``tp + fn``).
+            - ``'none'`` or ``None``: Calculate the metric for each class separately, and return
+              the metric for every class. Note that if a given class doesn't occur in the
+              `preds` or `target`, the value for the class will be ``nan``.
+
         ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
             to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
@@ -106,15 +143,13 @@
             [0, 0] for ``preds``, and [0, 2] for ``target``, then class 1 would be assigned the `absent_score`.
         threshold: Threshold value for binary or multi-label probabilities.
-        reduction: a method to reduce metric score over labels.
-
-            - ``'elementwise_mean'``: takes the mean (default)
-            - ``'sum'``: takes the sum
-            - ``'none'`` or ``None``: no reduction will be applied
 
     Return:
-        IoU score: Tensor containing single value if reduction is
-        'elementwise_mean', or number of classes if reduction is 'none'
+        The shape of the returned tensor depends on the ``average`` parameter
+
+        - If ``average in ['micro', 'macro', 'weighted']``, a one-element tensor will be returned
+        - If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
+          of classes
 
     Example:
         >>> from torchmetrics.functional import jaccard_index
@@ -126,4 +161,4 @@ def jaccard_index(
     """
     confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
-    return _jaccard_from_confmat(confmat, num_classes, ignore_index, absent_score, reduction)
+    return _jaccard_from_confmat(confmat, num_classes, average, ignore_index, absent_score)
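For reference, a minimal usage sketch of the new `average` argument (illustrative only, not part of the patch), reusing the toy inputs from `test_jaccard_ignore_index` above. The per-class values are the ones asserted in the `"none"` test cases; the macro/micro comments simply apply the definitions from the new docstring.

```python
import torch

from torchmetrics.functional import jaccard_index

# Same toy inputs as in test_jaccard_ignore_index above.
preds = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

# Per-class IoU, one score per class: [1, 1/2, 2/3]
per_class = jaccard_index(preds, target, num_classes=3, average=None)

# Unweighted mean of the per-class scores: (1 + 1/2 + 2/3) / 3
macro = jaccard_index(preds, target, num_classes=3, average="macro")

# Global intersection over global union: (1 + 1 + 2) / (1 + 2 + 3) = 2/3
micro = jaccard_index(preds, target, num_classes=3, average="micro")
```

`average="weighted"` additionally scales each per-class score by the class support (``tp + fn``) before summing, as described in the docstring, while `average=None` (or `"none"`) keeps one score per class, which is what the updated tests compare against scikit-learn's `jaccard_score`.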