diff --git a/mmseg/core/evaluation/__init__.py b/mmseg/core/evaluation/__init__.py
index f169d1bf1b..c58d926f06 100644
--- a/mmseg/core/evaluation/__init__.py
+++ b/mmseg/core/evaluation/__init__.py
@@ -1,7 +1,8 @@
 from .class_names import get_classes, get_palette
 from .eval_hooks import DistEvalHook, EvalHook
-from .mean_iou import mean_iou
+from .metrics import eval_metrics, mean_dice, mean_iou
 
 __all__ = [
-    'EvalHook', 'DistEvalHook', 'mean_iou', 'get_classes', 'get_palette'
+    'EvalHook', 'DistEvalHook', 'mean_dice', 'mean_iou', 'eval_metrics',
+    'get_classes', 'get_palette'
 ]
diff --git a/mmseg/core/evaluation/mean_iou.py b/mmseg/core/evaluation/mean_iou.py
deleted file mode 100644
index 301cfd04fb..0000000000
--- a/mmseg/core/evaluation/mean_iou.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import numpy as np
-
-
-def intersect_and_union(pred_label, label, num_classes, ignore_index):
-    """Calculate intersection and Union.
-
-    Args:
-        pred_label (ndarray): Prediction segmentation map
-        label (ndarray): Ground truth segmentation map
-        num_classes (int): Number of categories
-        ignore_index (int): Index that will be ignored in evaluation.
-
-    Returns:
-        ndarray: The intersection of prediction and ground truth histogram
-            on all classes
-        ndarray: The union of prediction and ground truth histogram on all
-            classes
-        ndarray: The prediction histogram on all classes.
-        ndarray: The ground truth histogram on all classes.
-    """
-
-    mask = (label != ignore_index)
-    pred_label = pred_label[mask]
-    label = label[mask]
-
-    intersect = pred_label[pred_label == label]
-    area_intersect, _ = np.histogram(
-        intersect, bins=np.arange(num_classes + 1))
-    area_pred_label, _ = np.histogram(
-        pred_label, bins=np.arange(num_classes + 1))
-    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
-    area_union = area_pred_label + area_label - area_intersect
-
-    return area_intersect, area_union, area_pred_label, area_label
-
-
-def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None):
-    """Calculate Intersection and Union (IoU)
-
-    Args:
-        results (list[ndarray]): List of prediction segmentation maps
-        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
-        num_classes (int): Number of categories
-        ignore_index (int): Index that will be ignored in evaluation.
-        nan_to_num (int, optional): If specified, NaN values will be replaced
-            by the numbers defined by the user. Default: None.
-
-    Returns:
-        float: Overall accuracy on all images.
-        ndarray: Per category accuracy, shape (num_classes, )
-        ndarray: Per category IoU, shape (num_classes, )
-    """
-
-    num_imgs = len(results)
-    assert len(gt_seg_maps) == num_imgs
-    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
-    total_area_union = np.zeros((num_classes, ), dtype=np.float)
-    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
-    total_area_label = np.zeros((num_classes, ), dtype=np.float)
-    for i in range(num_imgs):
-        area_intersect, area_union, area_pred_label, area_label = \
-            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
-                                ignore_index=ignore_index)
-        total_area_intersect += area_intersect
-        total_area_union += area_union
-        total_area_pred_label += area_pred_label
-        total_area_label += area_label
-    all_acc = total_area_intersect.sum() / total_area_label.sum()
-    acc = total_area_intersect / total_area_label
-    iou = total_area_intersect / total_area_union
-    if nan_to_num is not None:
-        return all_acc, np.nan_to_num(acc, nan=nan_to_num), \
-            np.nan_to_num(iou, nan=nan_to_num)
-    return all_acc, acc, iou
diff --git a/mmseg/core/evaluation/metrics.py b/mmseg/core/evaluation/metrics.py
new file mode 100644
index 0000000000..45c62b1641
--- /dev/null
+++ b/mmseg/core/evaluation/metrics.py
@@ -0,0 +1,176 @@
+import numpy as np
+
+
+def intersect_and_union(pred_label, label, num_classes, ignore_index):
+    """Calculate intersection and Union.
+
+    Args:
+        pred_label (ndarray): Prediction segmentation map
+        label (ndarray): Ground truth segmentation map
+        num_classes (int): Number of categories
+        ignore_index (int): Index that will be ignored in evaluation.
+
+    Returns:
+        ndarray: The intersection of prediction and ground truth histogram
+            on all classes
+        ndarray: The union of prediction and ground truth histogram on all
+            classes
+        ndarray: The prediction histogram on all classes.
+        ndarray: The ground truth histogram on all classes.
+    """
+
+    mask = (label != ignore_index)
+    pred_label = pred_label[mask]
+    label = label[mask]
+
+    intersect = pred_label[pred_label == label]
+    area_intersect, _ = np.histogram(
+        intersect, bins=np.arange(num_classes + 1))
+    area_pred_label, _ = np.histogram(
+        pred_label, bins=np.arange(num_classes + 1))
+    area_label, _ = np.histogram(label, bins=np.arange(num_classes + 1))
+    area_union = area_pred_label + area_label - area_intersect
+
+    return area_intersect, area_union, area_pred_label, area_label
+
+
+def total_intersect_and_union(results, gt_seg_maps, num_classes, ignore_index):
+    """Calculate Total Intersection and Union.
+
+    Args:
+        results (list[ndarray]): List of prediction segmentation maps
+        gt_seg_maps (list[ndarray]): list of ground truth segmentation maps
+        num_classes (int): Number of categories
+        ignore_index (int): Index that will be ignored in evaluation.
+
+    Returns:
+        ndarray: The intersection of prediction and ground truth histogram
+            on all classes
+        ndarray: The union of prediction and ground truth histogram on all
+            classes
+        ndarray: The prediction histogram on all classes.
+        ndarray: The ground truth histogram on all classes.
+    """
+
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_area_intersect = np.zeros((num_classes, ), dtype=np.float)
+    total_area_union = np.zeros((num_classes, ), dtype=np.float)
+    total_area_pred_label = np.zeros((num_classes, ), dtype=np.float)
+    total_area_label = np.zeros((num_classes, ), dtype=np.float)
+    for i in range(num_imgs):
+        area_intersect, area_union, area_pred_label, area_label = \
+            intersect_and_union(results[i], gt_seg_maps[i], num_classes,
+                                ignore_index=ignore_index)
+        total_area_intersect += area_intersect
+        total_area_union += area_union
+        total_area_pred_label += area_pred_label
+        total_area_label += area_label
+    return total_area_intersect, total_area_union, \
+        total_area_pred_label, total_area_label
+ """ + + num_imgs = len(results) + assert len(gt_seg_maps) == num_imgs + total_area_intersect = np.zeros((num_classes, ), dtype=np.float) + total_area_union = np.zeros((num_classes, ), dtype=np.float) + total_area_pred_label = np.zeros((num_classes, ), dtype=np.float) + total_area_label = np.zeros((num_classes, ), dtype=np.float) + for i in range(num_imgs): + area_intersect, area_union, area_pred_label, area_label = \ + intersect_and_union(results[i], gt_seg_maps[i], num_classes, + ignore_index=ignore_index) + total_area_intersect += area_intersect + total_area_union += area_union + total_area_pred_label += area_pred_label + total_area_label += area_label + return total_area_intersect, total_area_union, \ + total_area_pred_label, total_area_label + + +def mean_iou(results, gt_seg_maps, num_classes, ignore_index, nan_to_num=None): + """Calculate Mean Intersection and Union (mIoU) + + Args: + results (list[ndarray]): List of prediction segmentation maps + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps + num_classes (int): Number of categories + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ) + ndarray: Per category IoU, shape (num_classes, ) + """ + + all_acc, acc, iou = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mIoU'], + nan_to_num=nan_to_num) + return all_acc, acc, iou + + +def mean_dice(results, + gt_seg_maps, + num_classes, + ignore_index, + nan_to_num=None): + """Calculate Mean Dice (mDice) + + Args: + results (list[ndarray]): List of prediction segmentation maps + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps + num_classes (int): Number of categories + ignore_index (int): Index that will be ignored in evaluation. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + + Returns: + float: Overall accuracy on all images. + ndarray: Per category accuracy, shape (num_classes, ) + ndarray: Per category dice, shape (num_classes, ) + """ + + all_acc, acc, dice = eval_metrics( + results=results, + gt_seg_maps=gt_seg_maps, + num_classes=num_classes, + ignore_index=ignore_index, + metrics=['mDice'], + nan_to_num=nan_to_num) + return all_acc, acc, dice + + +def eval_metrics(results, + gt_seg_maps, + num_classes, + ignore_index, + metrics=['mIoU'], + nan_to_num=None): + """Calculate evaluation metrics + Args: + results (list[ndarray]): List of prediction segmentation maps + gt_seg_maps (list[ndarray]): list of ground truth segmentation maps + num_classes (int): Number of categories + ignore_index (int): Index that will be ignored in evaluation. + metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. + nan_to_num (int, optional): If specified, NaN values will be replaced + by the numbers defined by the user. Default: None. + Returns: + float: Overall accuracy on all images. 
+        ndarray: Per category accuracy, shape (num_classes, )
+        ndarray: Per category evaluation metrics, shape (num_classes, )
+    """
+
+    if isinstance(metrics, str):
+        metrics = [metrics]
+    allowed_metrics = ['mIoU', 'mDice']
+    if not set(metrics).issubset(set(allowed_metrics)):
+        raise KeyError('metrics {} is not supported'.format(metrics))
+    total_area_intersect, total_area_union, total_area_pred_label, \
+        total_area_label = total_intersect_and_union(results, gt_seg_maps,
+                                                     num_classes,
+                                                     ignore_index=ignore_index)
+    all_acc = total_area_intersect.sum() / total_area_label.sum()
+    acc = total_area_intersect / total_area_label
+    ret_metrics = [all_acc, acc]
+    for metric in metrics:
+        if metric == 'mIoU':
+            iou = total_area_intersect / total_area_union
+            ret_metrics.append(iou)
+        elif metric == 'mDice':
+            dice = 2 * total_area_intersect / (
+                total_area_pred_label + total_area_label)
+            ret_metrics.append(dice)
+    if nan_to_num is not None:
+        ret_metrics = [
+            np.nan_to_num(metric, nan=nan_to_num) for metric in ret_metrics
+        ]
+    return ret_metrics
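
Reviewer note, not part of the patch: a minimal usage sketch of the new eval_metrics entry point, assuming only the return convention documented above — [all_acc, acc] followed by one per-class array per requested metric, appended in the order the metrics are listed.

    import numpy as np

    from mmseg.core.evaluation import eval_metrics

    num_classes, ignore_index = 4, 255
    # Two small random "images"; real inputs are per-image label maps.
    preds = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]
    gts = [np.random.randint(0, num_classes, size=(32, 32)) for _ in range(2)]

    # Requesting ['mDice', 'mIoU'] yields the dice array before the IoU array.
    all_acc, acc, dice, iou = eval_metrics(
        preds, gts, num_classes, ignore_index, metrics=['mDice', 'mIoU'])
    assert acc.shape == dice.shape == iou.shape == (num_classes, )
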
""" - if not isinstance(metric, str): - assert len(metric) == 1 - metric = metric[0] - allowed_metrics = ['mIoU'] - if metric not in allowed_metrics: + if isinstance(metric, str): + metric = [metric] + allowed_metrics = ['mIoU', 'mDice'] + if not set(metric).issubset(set(allowed_metrics)): raise KeyError('metric {} is not supported'.format(metric)) - eval_results = {} gt_seg_maps = self.get_gt_seg_maps() if self.CLASSES is None: @@ -337,35 +336,42 @@ def evaluate(self, results, metric='mIoU', logger=None, **kwargs): reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) else: num_classes = len(self.CLASSES) - - all_acc, acc, iou = mean_iou( - results, gt_seg_maps, num_classes, ignore_index=self.ignore_index) - summary_str = '' - summary_str += 'per class results:\n' - - line_format = '{:<15} {:>10} {:>10}\n' - summary_str += line_format.format('Class', 'IoU', 'Acc') + ret_metrics = eval_metrics( + results, + gt_seg_maps, + num_classes, + ignore_index=self.ignore_index, + metrics=metric) + class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']] if self.CLASSES is None: class_names = tuple(range(num_classes)) else: class_names = self.CLASSES + ret_metrics_round = [ + np.round(ret_metric * 100, 2) for ret_metric in ret_metrics + ] for i in range(num_classes): - iou_str = '{:.2f}'.format(iou[i] * 100) - acc_str = '{:.2f}'.format(acc[i] * 100) - summary_str += line_format.format(class_names[i], iou_str, acc_str) - summary_str += 'Summary:\n' - line_format = '{:<15} {:>10} {:>10} {:>10}\n' - summary_str += line_format.format('Scope', 'mIoU', 'mAcc', 'aAcc') - - iou_str = '{:.2f}'.format(np.nanmean(iou) * 100) - acc_str = '{:.2f}'.format(np.nanmean(acc) * 100) - all_acc_str = '{:.2f}'.format(all_acc * 100) - summary_str += line_format.format('global', iou_str, acc_str, - all_acc_str) - print_log(summary_str, logger) - - eval_results['mIoU'] = np.nanmean(iou) - eval_results['mAcc'] = np.nanmean(acc) - eval_results['aAcc'] = all_acc - + class_table_data.append([class_names[i]] + + [m[i] for m in ret_metrics_round[2:]] + + [ret_metrics_round[1][i]]) + summary_table_data = [['Scope'] + + ['m' + head + for head in class_table_data[0][1:]] + ['aAcc']] + ret_metrics_mean = [ + np.round(np.nanmean(ret_metric) * 100, 2) + for ret_metric in ret_metrics + ] + summary_table_data.append(['global'] + ret_metrics_mean[2:] + + [ret_metrics_mean[1]] + + [ret_metrics_mean[0]]) + print_log('per class results:', logger) + table = AsciiTable(class_table_data) + print_log('\n' + table.table, logger=logger) + print_log('Summary:', logger) + table = AsciiTable(summary_table_data) + print_log('\n' + table.table, logger=logger) + + for i in range(1, len(summary_table_data[0])): + eval_results[summary_table_data[0] + [i]] = summary_table_data[1][i] / 100.0 return eval_results diff --git a/requirements/runtime.txt b/requirements/runtime.txt index db5d81e01e..a8347b9c0c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,2 +1,3 @@ matplotlib numpy +terminaltables diff --git a/setup.cfg b/setup.cfg index a5fb07d401..708fb4ce33 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmseg -known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch +known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git 
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index db5d81e01e..a8347b9c0c 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,2 +1,3 @@
 matplotlib
 numpy
+terminaltables
diff --git a/setup.cfg b/setup.cfg
index a5fb07d401..708fb4ce33 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,torch
+known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,pytest,scipy,terminaltables,torch
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tests/test_data/test_dataset.py b/tests/test_data/test_dataset.py
index e933c200cc..2e19c30f08 100644
--- a/tests/test_data/test_dataset.py
+++ b/tests/test_data/test_dataset.py
@@ -159,17 +159,45 @@ def test_custom_dataset():
     for gt_seg_map in gt_seg_maps:
         h, w = gt_seg_map.shape
         pseudo_results.append(np.random.randint(low=0, high=7, size=(h, w)))
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
 
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mDice', 'mIoU'])
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
     # evaluation with CLASSES
     train_dataset.CLASSES = tuple(['a'] * 7)
-    eval_results = train_dataset.evaluate(pseudo_results)
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mIoU')
+    assert isinstance(eval_results, dict)
+    assert 'mIoU' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(pseudo_results, metric='mDice')
+    assert isinstance(eval_results, dict)
+    assert 'mDice' in eval_results
+    assert 'mAcc' in eval_results
+    assert 'aAcc' in eval_results
+
+    eval_results = train_dataset.evaluate(
+        pseudo_results, metric=['mIoU', 'mDice'])
     assert isinstance(eval_results, dict)
     assert 'mIoU' in eval_results
+    assert 'mDice' in eval_results
     assert 'mAcc' in eval_results
     assert 'aAcc' in eval_results
 
diff --git a/tests/test_mean_iou.py b/tests/test_mean_iou.py
deleted file mode 100644
index 74a2b78617..0000000000
--- a/tests/test_mean_iou.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import numpy as np
-
-from mmseg.core.evaluation import mean_iou
-
-
-def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
-    """Intersection over Union
-    Args:
-        pred_label (np.ndarray): 2D predict map
-        label (np.ndarray): label 2D label map
-        num_classes (int): number of categories
-        ignore_index (int): index ignore in evaluation
-    """
-
-    mask = (label != ignore_index)
-    pred_label = pred_label[mask]
-    label = label[mask]
-
-    n = num_classes
-    inds = n * label + pred_label
-
-    mat = np.bincount(inds, minlength=n**2).reshape(n, n)
-
-    return mat
-
-
-# This func is deprecated since it's not memory efficient
-def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
-    num_imgs = len(results)
-    assert len(gt_seg_maps) == num_imgs
-    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
-    for i in range(num_imgs):
-        mat = get_confusion_matrix(
-            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
-        total_mat += mat
-    all_acc = np.diag(total_mat).sum() / total_mat.sum()
-    acc = np.diag(total_mat) / total_mat.sum(axis=1)
-    iou = np.diag(total_mat) / (
-        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))
-
-    return all_acc, acc, iou
-
-
-def test_mean_iou():
-    pred_size = (10, 30, 30)
-    num_classes = 19
-    ignore_index = 255
-    results = np.random.randint(0, num_classes, size=pred_size)
-    label = np.random.randint(0, num_classes, size=pred_size)
-    label[:, 2, 5:10] = ignore_index
-    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
-    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
-                                              ignore_index)
-    assert all_acc == all_acc_l
-    assert np.allclose(acc, acc_l)
-    assert np.allclose(iou, iou_l)
-
-    results = np.random.randint(0, 5, size=pred_size)
-    label = np.random.randint(0, 4, size=pred_size)
-    all_acc, acc, iou = mean_iou(
-        results, label, num_classes, ignore_index=255, nan_to_num=-1)
-    assert acc[-1] == -1
-    assert iou[-1] == -1
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000000..023bbb0a55
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,166 @@
+import numpy as np
+
+from mmseg.core.evaluation import eval_metrics, mean_dice, mean_iou
+
+
+def get_confusion_matrix(pred_label, label, num_classes, ignore_index):
+    """Intersection over Union
+    Args:
+        pred_label (np.ndarray): 2D predict map
+        label (np.ndarray): label 2D label map
+        num_classes (int): number of categories
+        ignore_index (int): index ignore in evaluation
+    """
+
+    mask = (label != ignore_index)
+    pred_label = pred_label[mask]
+    label = label[mask]
+
+    n = num_classes
+    inds = n * label + pred_label
+
+    mat = np.bincount(inds, minlength=n**2).reshape(n, n)
+
+    return mat
+
+
+# This func is deprecated since it's not memory efficient
+def legacy_mean_iou(results, gt_seg_maps, num_classes, ignore_index):
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
+    for i in range(num_imgs):
+        mat = get_confusion_matrix(
+            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
+        total_mat += mat
+    all_acc = np.diag(total_mat).sum() / total_mat.sum()
+    acc = np.diag(total_mat) / total_mat.sum(axis=1)
+    iou = np.diag(total_mat) / (
+        total_mat.sum(axis=1) + total_mat.sum(axis=0) - np.diag(total_mat))
+
+    return all_acc, acc, iou
+
+
+# This func is deprecated since it's not memory efficient
+def legacy_mean_dice(results, gt_seg_maps, num_classes, ignore_index):
+    num_imgs = len(results)
+    assert len(gt_seg_maps) == num_imgs
+    total_mat = np.zeros((num_classes, num_classes), dtype=np.float)
+    for i in range(num_imgs):
+        mat = get_confusion_matrix(
+            results[i], gt_seg_maps[i], num_classes, ignore_index=ignore_index)
+        total_mat += mat
+    all_acc = np.diag(total_mat).sum() / total_mat.sum()
+    acc = np.diag(total_mat) / total_mat.sum(axis=1)
+    dice = 2 * np.diag(total_mat) / (
+        total_mat.sum(axis=1) + total_mat.sum(axis=0))
+
+    return all_acc, acc, dice
+
+
+def test_metrics():
+    pred_size = (10, 30, 30)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    label = np.random.randint(0, num_classes, size=pred_size)
+    label[:, 2, 5:10] = ignore_index
+    all_acc, acc, iou = eval_metrics(
+        results, label, num_classes, ignore_index, metrics='mIoU')
+    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
+                                              ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(acc, acc_l)
+    assert np.allclose(iou, iou_l)
+
+    all_acc, acc, dice = eval_metrics(
+        results, label, num_classes, ignore_index, metrics='mDice')
+    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
+                                                ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(acc, acc_l)
+    assert np.allclose(dice, dice_l)
+
+    all_acc, acc, iou, dice = eval_metrics(
+        results, label, num_classes, ignore_index, metrics=['mIoU', 'mDice'])
+    assert all_acc == all_acc_l
+    assert np.allclose(acc, acc_l)
+    assert np.allclose(iou, iou_l)
+    assert np.allclose(dice, dice_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    all_acc, acc, iou = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index=255,
+        metrics='mIoU',
+        nan_to_num=-1)
+    assert acc[-1] == -1
+    assert iou[-1] == -1
+
+    all_acc, acc, dice = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index=255,
+        metrics='mDice',
+        nan_to_num=-1)
+    assert acc[-1] == -1
+    assert dice[-1] == -1
+
+    all_acc, acc, dice, iou = eval_metrics(
+        results,
+        label,
+        num_classes,
+        ignore_index=255,
+        metrics=['mDice', 'mIoU'],
+        nan_to_num=-1)
+    assert acc[-1] == -1
+    assert dice[-1] == -1
+    assert iou[-1] == -1
+
+
+def test_mean_iou():
+    pred_size = (10, 30, 30)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    label = np.random.randint(0, num_classes, size=pred_size)
+    label[:, 2, 5:10] = ignore_index
+    all_acc, acc, iou = mean_iou(results, label, num_classes, ignore_index)
+    all_acc_l, acc_l, iou_l = legacy_mean_iou(results, label, num_classes,
+                                              ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(acc, acc_l)
+    assert np.allclose(iou, iou_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    all_acc, acc, iou = mean_iou(
+        results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    assert acc[-1] == -1
+    assert iou[-1] == -1
+
+
+def test_mean_dice():
+    pred_size = (10, 30, 30)
+    num_classes = 19
+    ignore_index = 255
+    results = np.random.randint(0, num_classes, size=pred_size)
+    label = np.random.randint(0, num_classes, size=pred_size)
+    label[:, 2, 5:10] = ignore_index
+    all_acc, acc, dice = mean_dice(results, label, num_classes, ignore_index)
+    all_acc_l, acc_l, dice_l = legacy_mean_dice(results, label, num_classes,
+                                                ignore_index)
+    assert all_acc == all_acc_l
+    assert np.allclose(acc, acc_l)
+    assert np.allclose(dice, dice_l)
+
+    results = np.random.randint(0, 5, size=pred_size)
+    label = np.random.randint(0, 4, size=pred_size)
+    all_acc, acc, dice = mean_dice(
+        results, label, num_classes, ignore_index=255, nan_to_num=-1)
+    assert acc[-1] == -1
+    assert dice[-1] == -1
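
Editor's note, not part of the patch: both legacy reference implementations above hinge on one indexing trick — each (label, prediction) pair is flattened into the single index n * label + pred, tallied with np.bincount, and reshaped into an n x n confusion matrix with ground truth on rows and predictions on columns. A self-contained sketch:

    import numpy as np

    n = 3  # number of classes
    label = np.array([0, 0, 1, 2, 2])  # ground truth
    pred = np.array([0, 1, 1, 2, 0])   # predictions

    # Each (label, pred) pair maps to a unique index in [0, n*n); bincount
    # counts occurrences, and reshaping yields the confusion matrix.
    mat = np.bincount(n * label + pred, minlength=n**2).reshape(n, n)
    assert mat[0, 1] == 1  # one pixel of class 0 predicted as class 1
    assert np.diag(mat).sum() == 3  # three correctly classified pixels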