Skip to content

Commit

Permalink
Add a test case in test_metrics() of tests/test_metrics.py to test th…
Browse files Browse the repository at this point in the history
…e bug caused by torch.histc;
  • Loading branch information
sennnnn committed Apr 13, 2021
1 parent 2cffeaf commit 50ce86d
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,36 @@ def test_metrics():
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
label = np.random.randint(0, num_classes, size=pred_size)

# Test the availability of arg: ignore_index.
label[:, 2, 5:10] = ignore_index

# Test the correctness of the implementation of mIoU calculation.
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index, metrics='mIoU')
all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)

# Test the correctness of the implementation of mDice calculation.
all_acc, acc, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics='mDice')
all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
ignore_index)
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(dice, dice_l)

# Test the correctness of the implementation of joint calculation.
all_acc, acc, iou, dice = eval_metrics(
results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
assert all_acc == all_acc_l
assert np.allclose(acc, acc_l)
assert np.allclose(iou, iou_l)
assert np.allclose(dice, dice_l)

# Test the correctness of calculation when arg: num_classes is larger
# than the maximum value of input maps.
results = np.random.randint(0, 5, size=pred_size)
label = np.random.randint(0, 4, size=pred_size)
all_acc, acc, iou = eval_metrics(
Expand Down Expand Up @@ -121,6 +127,17 @@ def test_metrics():
assert dice[-1] == -1
assert iou[-1] == -1

# Test the bug which is caused by torch.histc.
# torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html
# When the arg:bins is set to be same as arg:max,
# some channels of mIoU may be nan.
results = np.array([np.repeat(31, 59)])
label = np.array([np.arange(59)])
num_classes = 59
all_acc, acc, iou = eval_metrics(
results, label, num_classes, ignore_index=255, metrics='mIoU')
assert np.sum(np.isnan(iou)) == 0


def test_mean_iou():
pred_size = (10, 30, 30)
Expand Down Expand Up @@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str):
filenames.append(filename)
return filenames

pred_size = (10, 512, 1024)
pred_size = (10, 30, 30)
num_classes = 19
ignore_index = 255
results = np.random.randint(0, num_classes, size=pred_size)
Expand Down

0 comments on commit 50ce86d

Please sign in to comment.