Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ignore_index in the computation of IoU #328

Merged
merged 12 commits into from
Jul 29, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))


## [0.4.1] - 2021-07-05

### Changed
Expand Down
4 changes: 2 additions & 2 deletions tests/classification/test_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_iou_differentiability(self, reduction, preds, target, sk_metric, num_cl
pytest.param(False, 'none', 0, Tensor([1, 1])),
pytest.param(True, 'none', None, Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', None, Tensor([0.5])),
pytest.param(True, 'none', 0, Tensor([0.5, 0.5])),
pytest.param(True, 'none', 0, Tensor([2 / 3, 1 / 2])),
])
def test_iou(half_ones, reduction, ignore_index, expected):
preds = (torch.arange(120) % 3).view(-1, 1)
Expand Down Expand Up @@ -226,7 +226,7 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
# Ignoring a valid index drops only that index from the result.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1]),
# When reducing to mean or sum, the ignored index does not contribute to the output.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
Expand Down
7 changes: 6 additions & 1 deletion torchmetrics/functional/classification/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,24 @@ def _iou_from_confmat(
absent_score: float = 0.0,
reduction: str = 'elementwise_mean',
) -> Tensor:

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
confmat[ignore_index] = 0.

intersection = torch.diag(confmat)
union = confmat.sum(0) + confmat.sum(1) - intersection

# If this class is absent in both target AND pred (union == 0), then use the absent_score for this class.
scores = intersection.float() / union.float()
scores[union == 0] = absent_score

# Remove the ignored class index from the scores.
if ignore_index is not None and 0 <= ignore_index < num_classes:
scores = torch.cat([
scores[:ignore_index],
scores[ignore_index + 1:],
])

return reduce(scores, reduction=reduction)


Expand Down