Fix mIoU calculation range #471

Merged
3 commits merged on Apr 14, 2021
6 changes: 3 additions & 3 deletions mmseg/core/evaluation/metrics.py
@@ -57,11 +57,11 @@ def intersect_and_union(pred_label,

intersect = pred_label[pred_label == label]
area_intersect = torch.histc(
-    intersect.float(), bins=(num_classes), min=0, max=num_classes)
+    intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_pred_label = torch.histc(
-    pred_label.float(), bins=(num_classes), min=0, max=num_classes)
+    pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_label = torch.histc(
-    label.float(), bins=(num_classes), min=0, max=num_classes)
+    label.float(), bins=(num_classes), min=0, max=num_classes - 1)
area_union = area_pred_label + area_label - area_intersect
return area_intersect, area_union, area_pred_label, area_label

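Note (not part of the diff): a minimal sketch of why the histc range matters. With max=num_classes every integer label sits exactly on a bin edge, so floating-point rounding inside torch.histc can push a count into the neighbouring bin; with max=num_classes - 1 each label lands inside its own bin (the last one being the inclusive upper bound). The snippet only illustrates the documented torch.histc behaviour; which classes are affected by the old range may vary with the PyTorch version.

import torch

num_classes = 59
label = torch.arange(num_classes).float()  # one pixel per class

# Old range: every integer label lies on a bin edge, so some bins may
# end up empty while a neighbouring bin is counted twice.
hist_old = torch.histc(label, bins=num_classes, min=0, max=num_classes)

# Fixed range: each label falls inside its own bin, so every class is
# counted exactly once.
hist_new = torch.histc(label, bins=num_classes, min=0, max=num_classes - 1)

print((hist_old == 1).all())  # may be False, depending on rounding
print((hist_new == 1).all())  # tensor(True)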
23 changes: 20 additions & 3 deletions tests/test_metrics.py
@@ -64,30 +64,36 @@ def test_metrics():
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)

# Test the availability of arg: ignore_index.
label[:, 2, 5:10] = ignore_index

# Test the correctness of the implementation of mIoU calculation.
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index, metrics='mIoU')
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)

# Test the correctness of the implementation of mDice calculation.
all_acc, acc, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics='mDice')
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(dice, dice_l)

# Test the correctness of the implementation of joint calculation.
all_acc, acc, iou, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
assert np.allclose(dice, dice_l)

# Test the correctness of calculation when arg: num_classes is larger
# than the maximum value of input maps.
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
all_acc, acc, iou = eval_metrics(
@@ -121,6 +127,17 @@ def test_metrics():
assert dice[-1] == -1
assert iou[-1] == -1

# Test the bug caused by torch.histc.
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
# When arg: bins is set to the same value as arg: max,
# some entries of the computed IoU may be nan.
results = np.array([np.repeat(31, 59)])
label = np.array([np.arange(59)])
num_classes = 59
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index=255, metrics='mIoU')
assert np.sum(np.isnan(iou)) == 0


def test_mean_iou():
pred_size = (10, 30, 30)
@@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
filenames.append(filename)
return filenames

- pred_size = (10, 512, 1024)
+ pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
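For reference, a self-contained sketch of the intersection/union computation this PR touches, applied to the case exercised by the new unit test. It is simplified: the ignore_index handling and the surrounding mean_iou bookkeeping from mmseg are omitted, the helper name is hypothetical, and inputs are assumed to already be integer class-index tensors.

import torch

def intersect_and_union_sketch(pred_label, label, num_classes):
    # Pixels where prediction and ground truth agree, kept as class ids.
    intersect = pred_label[pred_label == label]
    # Per-class pixel counts; max=num_classes - 1 matches the fix above.
    area_intersect = torch.histc(
        intersect.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label

# Same inputs as the new test: every pixel predicted as class 31,
# ground truth covering classes 0..58.
pred = torch.full((59,), 31)
gt = torch.arange(59)
inter, union, _, _ = intersect_and_union_sketch(pred, gt, num_classes=59)
iou = inter / union  # a zero union (possible with the old range) gives nan
print(int(torch.isnan(iou).sum()))  # 0 with the corrected range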