Skip to content

Commit

Permalink
Fix IoU score for classes not present in target or pred
Browse files Browse the repository at this point in the history
Fixes #3097

- Allow configurable not_present_score for IoU for classes
  not present in target or pred. Defaults to 1.0.
- Also allow passing `num_classes` parameter through from iou
  metric class down to its underlying functional iou
  call.
  • Loading branch information
abrahambotros committed Sep 1, 2020
1 parent a5288fe commit e75f32f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed batch size auto-scaling feature to set the new value on the correct model attribute ([#3043](https://github.com/PyTorchLightning/pytorch-lightning/pull/3043))
- Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
- Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))
- Fixed IoU score for classes not present in target or pred ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098))

## [0.8.5] - 2020-07-09

Expand Down
16 changes: 15 additions & 1 deletion pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,11 +797,16 @@ class IoU(TensorMetric):

def __init__(
self,
not_present_score: float = 1.0,
num_classes: Optional[int] = None,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'
):
"""
Args:
not_present_score: score to use for a class, if no instance of that class was present in either pred or
target
num_classes: Optionally specify the number of classes
remove_bg: Flag to state whether a background class has been included
within input parameters. If true, will remove background class. If
false, return IoU over all classes.
Expand All @@ -814,6 +819,8 @@ def __init__(
- sum: add elements
"""
super().__init__(name='iou')
self.not_present_score = not_present_score
self.num_classes = num_classes
self.remove_bg = remove_bg
self.reduction = reduction

Expand All @@ -822,4 +829,11 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor,
"""
Actual metric calculation.
"""
return iou(y_pred, y_true, remove_bg=self.remove_bg, reduction=self.reduction)
return iou(
pred=y_pred,
target=y_true,
not_present_score=self.not_present_score,
num_classes=self.num_classes,
remove_bg=self.remove_bg,
reduction=self.reduction,
)
34 changes: 25 additions & 9 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,16 +963,18 @@ def dice_score(
def iou(
pred: torch.Tensor,
target: torch.Tensor,
not_present_score: float = 1.0,
num_classes: Optional[int] = None,
remove_bg: bool = False,
reduction: str = 'elementwise_mean'
reduction: str = 'elementwise_mean',
) -> torch.Tensor:
"""
Intersection over union, or Jaccard index calculation.
Args:
pred: Tensor containing predictions
target: Tensor containing targets
not_present_score: score to use for a class, if no instance of that class was present in either pred or target
num_classes: Optionally specify the number of classes
remove_bg: Flag to state whether a background class has been included
within input parameters. If true, will remove background class. If
Expand All @@ -998,12 +1000,26 @@ def iou(
tensor(0.4914)
"""
num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)

# Determine minimum class index we will be evaluating. If using the background, then this is 0; otherwise, if
# removing background, use 1.
min_class_idx = 1 if remove_bg else 0

tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
if remove_bg:
tps = tps[1:]
fps = fps[1:]
fns = fns[1:]
denom = fps + fns + tps
denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
iou = tps / denom
return reduce(iou, reduction=reduction)

scores = torch.zeros(num_classes - min_class_idx, device=pred.device, dtype=torch.float32)
for class_idx in range(min_class_idx, num_classes):
# If this class is not present in either the pred or the target, then use the not_present_score for this class.
if not (target == class_idx).any() and not (pred == class_idx).any():
scores[class_idx - min_class_idx] = not_present_score
continue

tp = tps[class_idx]
fp = fps[class_idx]
fn = fns[class_idx]
denom = tp + fp + fn
score = tp.to(torch.float) / denom
scores[class_idx - min_class_idx] = score

return reduce(scores, reduction=reduction)
48 changes: 47 additions & 1 deletion tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,55 @@ def test_iou(half_ones, reduction, remove_bg, expected):
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
iou_val = iou(
pred=pred,
target=target,
remove_bg=remove_bg,
reduction=reduction,
)
assert torch.allclose(iou_val, expected, atol=1e-9)


@pytest.mark.parametrize(['pred', 'target', 'not_present_score', 'num_classes', 'remove_bg', 'expected'], [
# Note that -1 is used as the not_present_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
# 2 classes, class 0 is correct everywhere, class 1 is not present.
pytest.param([0], [0], -1., 2, False, [1., -1.]),
pytest.param([0, 0], [0, 0], -1., 2, False, [1., -1.]),
# not_present_score not applied if only class 0 is present and it's the only class.
pytest.param([0], [0], -1., 1, False, [1.]),
# 2 classes, class 1 is correct everywhere, class 0 is not present.
pytest.param([1], [1], -1., 2, False, [-1., 1.]),
pytest.param([1, 1], [1, 1], -1., 2, False, [-1., 1.]),
# When background removed, class 0 does not get a score (not even the not_present_score).
pytest.param([1], [1], -1., 2, True, [1.0]),
# 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get not_present_score.
pytest.param([0, 2], [0, 2], -1., 3, False, [1., -1., 1.]),
pytest.param([2, 0], [2, 0], -1., 3, False, [1., -1., 1.]),
# 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get not_present_score.
pytest.param([0, 1], [0, 1], -1., 3, False, [1., 1., -1.]),
pytest.param([1, 0], [1, 0], -1., 3, False, [1., 1., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get not_present_score), class
# 2 is not present.
pytest.param([0, 1], [0, 0], -1., 3, False, [0.5, 0., -1.]),
# 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get not_present_score), class
# 2 is not present.
pytest.param([0, 0], [0, 1], -1., 3, False, [0.5, 0., -1.]),
# Sanity checks with not_present_score of 1.0.
pytest.param([0, 2], [0, 2], 1.0, 3, False, [1., 1., 1.]),
pytest.param([0, 2], [0, 2], 1.0, 3, True, [1., 1.]),
])
def test_iou_not_present_score(pred, target, not_present_score, num_classes, remove_bg, expected):
iou_val = iou(
pred=torch.tensor(pred),
target=torch.tensor(target),
not_present_score=not_present_score,
num_classes=num_classes,
remove_bg=remove_bg,
reduction='none',
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

0 comments on commit e75f32f

Please sign in to comment.