Skip to content

Commit

Permalink
fix ignore_index in the computation of IoU (#328)
Browse files Browse the repository at this point in the history
* fix ignore_index

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ignore_index in iou

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changelog

Co-authored-by: CSautier <corentin.sautier@valeo.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
5 people authored Jul 29, 2021
1 parent 85ebc3b commit c9d36b2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))


## [0.4.1] - 2021-07-05

### Changed
Expand Down
4 changes: 2 additions & 2 deletions tests/classification/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_cl
pytest.param(False, 'none', 0, Tensor([1, 1])),
pytest.param(True, 'none', None, Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, Tensor([0.5])),
pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
pytest.param(True, 'none', 0, Tensor([2 / 3, 1 / 2])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
preds = (torch.arange(120) % 3).view(-1, 1)
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
Expand Down
7 changes: 6 additions & 1 deletion torchmetrics/functional/classification/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,24 @@ def _iou_from_confmat(
absent_score: float = 0.0,
reduction: str = 'elementwise_mean',
) -> Tensor:

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.

intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
])

return reduce(scores, reduction=reduction)


Expand Down

0 comments on commit c9d36b2

Please sign in to comment.