From 50ce86d84c1f5926343c3242e45319a2481c5c2e Mon Sep 17 00:00:00 2001 From: sennnnn <201730271412@mail.scut.edu.cn> Date: Tue, 13 Apr 2021 15:51:11 +0800 Subject: [PATCH] Add a test case in test_metrics() of tests/test_metrics.py to test the bug caused by torch.histc; --- tests/test_metrics.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 2033617c2a..1f4330df24 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -64,7 +64,11 @@ def test_metrics(): ignore_index = 255 results = np.random.randint(0, num_classes, size=pred_size) label = np.random.randint(0, num_classes, size=pred_size) + + # Test the availability of arg: ignore_index. label[:, 2, 5:10] = ignore_index + + # Test the correctness of the implementation of mIoU calculation. all_acc, acc, iou = eval_metrics( results, label, num_classes, ignore_index, metrics='mIoU') all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes, @@ -72,7 +76,7 @@ def test_metrics(): assert all_acc == all_acc_l assert np.allclose(acc, acc_l) assert np.allclose(iou, iou_l) - + # Test the correctness of the implementation of mDice calculation. all_acc, acc, dice = eval_metrics( results, label, num_classes, ignore_index, metrics='mDice') all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes, @@ -80,7 +84,7 @@ def test_metrics(): assert all_acc == all_acc_l assert np.allclose(acc, acc_l) assert np.allclose(dice, dice_l) - + # Test the correctness of the implementation of joint calculation. all_acc, acc, iou, dice = eval_metrics( results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice']) assert all_acc == all_acc_l @@ -88,6 +92,8 @@ def test_metrics(): assert np.allclose(iou, iou_l) assert np.allclose(dice, dice_l) + # Test the correctness of calculation when arg: num_classes is larger + # than the maximum value of input maps. results = np.random.randint(0, 5, size=pred_size) label = np.random.randint(0, 4, size=pred_size) all_acc, acc, iou = eval_metrics( @@ -121,6 +127,17 @@ def test_metrics(): assert dice[-1] == -1 assert iou[-1] == -1 + # Test the bug which is caused by torch.histc. + # torch.histc: https://pytorch.org/docs/stable/generated/torch.histc.html + # When the arg:bins is set to be same as arg:max, + # some channels of mIoU may be nan. + results = np.array([np.repeat(31, 59)]) + label = np.array([np.arange(59)]) + num_classes = 59 + all_acc, acc, iou = eval_metrics( + results, label, num_classes, ignore_index=255, metrics='mIoU') + assert np.sum(np.isnan(iou)) == 0 + def test_mean_iou(): pred_size = (10, 30, 30) @@ -182,7 +199,7 @@ def save_arr(input_arrays: list, title: str, is_image: bool, dir: str): filenames.append(filename) return filenames - pred_size = (10, 512, 1024) + pred_size = (10, 30, 30) num_classes = 19 ignore_index = 255 results = np.random.randint(0, num_classes, size=pred_size)