From fef503dec38a0149c9ce3e5fe1a6c435d8153f4b Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 22 Oct 2021 15:17:13 +0800
Subject: [PATCH 1/6] [Feature]: Support plot confusion matrix

---
 tools/analysis_tools/confusion_matrix.py | 221 +++++++++++++++++++++++
 1 file changed, 221 insertions(+)
 create mode 100644 tools/analysis_tools/confusion_matrix.py

diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
new file mode 100644
index 00000000000..20298f9d8bd
--- /dev/null
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -0,0 +1,221 @@
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+from matplotlib.ticker import MultipleLocator
+from mmcv import Config, DictAction
+from mmcv.ops import nms
+
+from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
+from mmdet.datasets import build_dataset
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Generate confusion matrix from detection results')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument(
+        'prediction_path', help='prediction path where test .pkl result')
+    parser.add_argument(
+        'save_dir', help='directory where confusion matrix will be saved')
+    parser.add_argument(
+        '--show', action='store_true', help='show confusion matrix')
+    parser.add_argument(
+        '--color-theme',
+        default='plasma',
+        help='theme of the matrix color map')
+    parser.add_argument(
+        '--score-thr',
+        type=float,
+        default=0,
+        help='score threshold to filter detection bboxes')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into config file. If the value to '
+        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+        'Note that the quotation marks are necessary and that no white space '
+        'is allowed.')
+    args = parser.parse_args()
+    return args
+
+
+def calculate_confusion_matrix(dataset, results):
+    num_classes = len(dataset.CLASSES)
+    confusion_matrix = np.zeros(shape=[num_classes + 1, num_classes + 1])
+    assert len(dataset) == len(results)
+    prog_bar = mmcv.ProgressBar(len(results))
+    for idx, per_img_res in enumerate(results):
+        if isinstance(per_img_res, tuple):
+            res_bboxes, _ = per_img_res
+        else:
+            res_bboxes = per_img_res
+        ann = dataset.get_ann_info(idx)
+        gt_bboxes = ann['bboxes']
+        labels = ann['labels']
+        analyze_per_img_dets(confusion_matrix, gt_bboxes, labels, res_bboxes,
+                             0.5)
+        prog_bar.update()
+    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
+    normalized_matrix = \
+        confusion_matrix.astype('float') / per_label_sums * 100
+    return normalized_matrix
+
+
+def analyze_per_img_dets(confusion_matrix,
+                         gt_bboxes,
+                         gt_labels,
+                         result,
+                         score_thr=0,
+                         iou_thr=0.5,
+                         nms_iou_thr=None):
+    """Analyze detection results on each image.
+
+    Args:
+        confusion_matrix (ndarray): The confusion matrix,
+            has shape (num_classes + 1, num_classes + 1).
+        gt_bboxes (ndarray): Ground truth bboxes, has shape (num_gt, 4).
+        gt_labels (ndarray): Ground truth labels, has shape (num_gt).
+        result (ndarray): Detection results, has shape
+            (num_classes, num_bboxes, 5).
+        score_thr (float|optional): Score threshold to filter bboxes.
+            Default: 0.
+        iou_thr (float|optional): IoU threshold to be considered as matched.
+            Default: 0.5.
+        nms_iou_thr (float|optional): nms IOU threshold, the detection results
+            have done nms in the detector, only applied when users want to
+            change the nms iou threshold. Default: None.
+    """
+    true_positives = np.zeros_like(gt_labels)
+    for det_label, det_bboxes in enumerate(result):
+        if nms_iou_thr:
+            det_bboxes, _ = nms(
+                det_bboxes[:, :4],
+                det_bboxes[:, -1],
+                nms_iou_thr,
+                score_threshold=score_thr)
+        ious = bbox_overlaps(det_bboxes[:, :4], gt_bboxes)
+        for i, det_bbox in enumerate(det_bboxes):
+            score = det_bbox[4]
+            det_match = 0
+            if score >= score_thr:
+                for j, gt_label in enumerate(gt_labels):
+                    if ious[i, j] >= iou_thr:
+                        det_match += 1
+                        if gt_label == det_label:
+                            true_positives[j] += 1  # TP
+                        confusion_matrix[gt_label, det_label] += 1
+            if det_match == 0:  # BG FP
+                confusion_matrix[-1, det_label] += 1
+    for num_tp, gt_label in zip(true_positives, gt_labels):
+        if num_tp == 0:  # FN
+            confusion_matrix[gt_label, -1] += 1
+
+
+def plot_confusion_matrix(confusion_matrix,
+                          labels,
+                          save_dir=None,
+                          show=True,
+                          title='Normalized Confusion Matrix',
+                          color_theme='plasma'):
+    """Draw confusion matrix with matplotlib.
+
+    Args:
+        confusion_matrix (ndarray): The confusion matrix.
+        labels (list[str]): List of class names.
+        save_dir (str|optional): If set, save the confusion matrix plot to the
+            given path. Default: None.
+        show (bool): Whether to show the plot. Default: True.
+        title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
+        color_theme (str): Theme of the matrix color map. Default: `plasma`.
+    """
+    num_classes = len(labels)
+    fig, ax = plt.subplots(
+        figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=300)
+    cmap = plt.get_cmap(color_theme)
+    im = ax.imshow(confusion_matrix, cmap=cmap)
+    plt.colorbar(mappable=im, ax=ax)
+
+    title_font = {'weight': 'bold', 'size': 12}
+    ax.set_title(title, fontdict=title_font)
+    label_font = {'size': 7}
+    plt.ylabel('Ground Truth Label', fontdict=label_font)
+    plt.xlabel('Prediction Label', fontdict=label_font)
+
+    # draw locator
+    xmajor_locator = MultipleLocator(1)
+    xminor_locator = MultipleLocator(0.5)
+    ax.xaxis.set_major_locator(xmajor_locator)
+    ax.xaxis.set_minor_locator(xminor_locator)
+    ymajor_locator = MultipleLocator(1)
+    yminor_locator = MultipleLocator(0.5)
+    ax.yaxis.set_major_locator(ymajor_locator)
+    ax.yaxis.set_minor_locator(yminor_locator)
+
+    # draw grid
+    ax.grid(True, which='minor', linestyle='-')
+
+    # draw label
+    ax.set_xticks(np.arange(num_classes))
+    ax.set_yticks(np.arange(num_classes))
+    ax.set_xticklabels(labels)
+    ax.set_yticklabels(labels)
+
+    ax.tick_params(
+        axis='x', bottom=False, top=True, labelbottom=False, labeltop=True)
+    plt.setp(
+        ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
+
+    # draw text
+    for i in range(num_classes):
+        for j in range(num_classes):
+            ax.text(
+                j,
+                i,
+                '%0.1f' % confusion_matrix[i, j],
+                ha='center',
+                va='center',
+                color='w',
+                size=7)
+
+    ax.set_ylim(len(confusion_matrix) - 0.5, -0.5)  # matplotlib>3.1.1
+
+    fig.tight_layout()
+    if save_dir is not None:
+        plt.savefig(
+            os.path.join(save_dir, 'confusion_matrix.png'), format='png')
+    if show:
+        plt.show()
+
+
+def main():
+    args = parse_args()
+
+    cfg = Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
+
+    outputs = mmcv.load(args.prediction_path)
+
+    if isinstance(cfg.data.test, dict):
+        cfg.data.test.test_mode = True
+    elif isinstance(cfg.data.test, list):
+        for ds_cfg in cfg.data.test:
+            ds_cfg.test_mode = True
+    dataset = build_dataset(cfg.data.test)
+
+    normalized_confusion_matrix = calculate_confusion_matrix(dataset, outputs)
+    plot_confusion_matrix(
+        normalized_confusion_matrix,
+        dataset.CLASSES + ('background', ),
+        save_dir=args.save_dir,
+        show=args.show)
+
+
+if __name__ == '__main__':
+    main()

From 1cd3a153d6de36cb563a0a86e6a96b1e1fb35f9a Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 22 Oct 2021 15:23:11 +0800
Subject: [PATCH 2/6] [Feature]: Support plot confusion matrix

---
 tools/analysis_tools/confusion_matrix.py | 27 +++++++++++++++++++-----
 1 file changed, 22 insertions(+), 5 deletions(-)

diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index 20298f9d8bd..c19adee7319 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -45,7 +45,24 @@ def parse_args():
     return args
 
 
-def calculate_confusion_matrix(dataset, results):
+def calculate_confusion_matrix(dataset,
+                               results,
+                               score_thr=0,
+                               nms_iou_thr=None,
+                               tp_iou_thr=0.5):
+    """Calculate the confusion matrix.
+
+    Args:
+        dataset (Dataset): Test or val dataset.
+        results (list[ndarray]): A list of detection results in each image.
+        score_thr (float|optional): Score threshold to filter bboxes.
+            Default: 0.
+        nms_iou_thr (float|optional): nms IOU threshold, the detection results
+            have done nms in the detector, only applied when users want to
+            change the nms iou threshold. Default: None.
+        tp_iou_thr (float|optional): IoU threshold to be considered as matched.
+            Default: 0.5.
+    """
     num_classes = len(dataset.CLASSES)
     confusion_matrix = np.zeros(shape=[num_classes + 1, num_classes + 1])
     assert len(dataset) == len(results)
@@ -59,7 +76,7 @@ def calculate_confusion_matrix(dataset, results):
         gt_bboxes = ann['bboxes']
         labels = ann['labels']
         analyze_per_img_dets(confusion_matrix, gt_bboxes, labels, res_bboxes,
-                             0.5)
+                             score_thr, tp_iou_thr, nms_iou_thr)
         prog_bar.update()
     per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
     normalized_matrix = \
@@ -72,7 +89,7 @@ def analyze_per_img_dets(confusion_matrix,
                          gt_labels,
                          result,
                          score_thr=0,
-                         iou_thr=0.5,
+                         tp_iou_thr=0.5,
                          nms_iou_thr=None):
     """Analyze detection results on each image.
 
@@ -85,7 +102,7 @@ def analyze_per_img_dets(confusion_matrix,
             (num_classes, num_bboxes, 5).
         score_thr (float|optional): Score threshold to filter bboxes.
            Default: 0.
-        iou_thr (float|optional): IoU threshold to be considered as matched.
+        tp_iou_thr (float|optional): IoU threshold to be considered as matched.
            Default: 0.5.
        nms_iou_thr (float|optional): nms IOU threshold, the detection results
            have done nms in the detector, only applied when users want to
            change the nms iou threshold. Default: None.
     """
     true_positives = np.zeros_like(gt_labels)
     for det_label, det_bboxes in enumerate(result):
@@ -105,7 +122,7 @@ def analyze_per_img_dets(confusion_matrix,
             det_match = 0
             if score >= score_thr:
                 for j, gt_label in enumerate(gt_labels):
-                    if ious[i, j] >= iou_thr:
+                    if ious[i, j] >= tp_iou_thr:
                         det_match += 1
                         if gt_label == det_label:
                             true_positives[j] += 1  # TP

From 89a933553df19353188c79af196240ae39e89cd8 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 22 Oct 2021 15:29:57 +0800
Subject: [PATCH 3/6] add args

---
 tools/analysis_tools/confusion_matrix.py | 26 +++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index c19adee7319..0ac9b6fb8fa 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -31,6 +31,17 @@ def parse_args():
         type=float,
         default=0,
         help='score threshold to filter detection bboxes')
+    parser.add_argument(
+        '--tp-iou-thr',
+        type=float,
+        default=0.5,
+        help='IoU threshold to be considered as matched')
+    parser.add_argument(
+        '--nms-iou-thr',
+        type=float,
+        default=None,
+        help='nms IoU threshold, only applied when users want to change the '
+        'nms IoU threshold.')
     parser.add_argument(
         '--cfg-options',
         nargs='+',
@@ -57,9 +68,9 @@ def calculate_confusion_matrix(dataset,
         results (list[ndarray]): A list of detection results in each image.
         score_thr (float|optional): Score threshold to filter bboxes.
             Default: 0.
-        nms_iou_thr (float|optional): nms IOU threshold, the detection results
+        nms_iou_thr (float|optional): nms IoU threshold, the detection results
             have done nms in the detector, only applied when users want to
-            change the nms iou threshold. Default: None.
+            change the nms IoU threshold. Default: None.
         tp_iou_thr (float|optional): IoU threshold to be considered as matched.
             Default: 0.5.
     """
@@ -100,13 +111,13 @@ def analyze_per_img_dets(confusion_matrix,
         gt_labels (ndarray): Ground truth labels, has shape (num_gt).
         result (ndarray): Detection results, has shape
             (num_classes, num_bboxes, 5).
-        score_thr (float|optional): Score threshold to filter bboxes.
+        score_thr (float): Score threshold to filter bboxes.
             Default: 0.
-        tp_iou_thr (float|optional): IoU threshold to be considered as matched.
+        tp_iou_thr (float): IoU threshold to be considered as matched.
             Default: 0.5.
-        nms_iou_thr (float|optional): nms IOU threshold, the detection results
+        nms_iou_thr (float|optional): nms IoU threshold, the detection results
             have done nms in the detector, only applied when users want to
-            change the nms iou threshold. Default: None.
+            change the nms IoU threshold. Default: None.
     """
     true_positives = np.zeros_like(gt_labels)
     for det_label, det_bboxes in enumerate(result):
@@ -226,7 +237,8 @@ def main():
             ds_cfg.test_mode = True
     dataset = build_dataset(cfg.data.test)
 
-    normalized_confusion_matrix = calculate_confusion_matrix(dataset, outputs)
+    normalized_confusion_matrix = calculate_confusion_matrix(
+        dataset, outputs, args.score_thr, args.nms_iou_thr, args.tp_iou_thr)
     plot_confusion_matrix(
         normalized_confusion_matrix,
         dataset.CLASSES + ('background', ),
         save_dir=args.save_dir,
         show=args.show)

From 96d76b1964664c5321c52f144e1c727a61489122 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 22 Oct 2021 15:52:36 +0800
Subject: [PATCH 4/6] dpi

---
 tools/analysis_tools/confusion_matrix.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index 0ac9b6fb8fa..d781cc02206 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -164,7 +164,7 @@ def plot_confusion_matrix(confusion_matrix,
     """
     num_classes = len(labels)
     fig, ax = plt.subplots(
-        figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=300)
+        figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=180)
     cmap = plt.get_cmap(color_theme)
     im = ax.imshow(confusion_matrix, cmap=cmap)
     plt.colorbar(mappable=im, ax=ax)

From 1ae048c6bd75965c61769e5ecea0fe2645bc80b6 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 5 Nov 2021 15:52:41 +0800
Subject: [PATCH 5/6] fix type

---
 tools/analysis_tools/confusion_matrix.py | 31 ++++++++++++++++--------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index d781cc02206..f6af1750712 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -29,7 +29,7 @@ def parse_args():
     parser.add_argument(
         '--score-thr',
         type=float,
-        default=0,
+        default=0.3,
         help='score threshold to filter detection bboxes')
     parser.add_argument(
         '--tp-iou-thr',
         type=float,
         default=0.5,
         help='IoU threshold to be considered as matched')
@@ -89,10 +89,7 @@ def calculate_confusion_matrix(dataset,
         analyze_per_img_dets(confusion_matrix, gt_bboxes, labels, res_bboxes,
                              score_thr, tp_iou_thr, nms_iou_thr)
         prog_bar.update()
-    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
-    normalized_matrix = \
-        confusion_matrix.astype('float') / per_label_sums * 100
-    return normalized_matrix
+    return confusion_matrix
 
 
 def analyze_per_img_dets(confusion_matrix,
@@ -162,6 +159,11 @@ def plot_confusion_matrix(confusion_matrix,
         title (str): Title of the plot. Default: `Normalized Confusion Matrix`.
         color_theme (str): Theme of the matrix color map. Default: `plasma`.
     """
+    # normalize the confusion matrix
+    per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis]
+    confusion_matrix = \
+        confusion_matrix.astype(np.float32) / per_label_sums * 100
+
     num_classes = len(labels)
     fig, ax = plt.subplots(
         figsize=(0.5 * num_classes, 0.5 * num_classes * 0.8), dpi=180)
     cmap = plt.get_cmap(color_theme)
     im = ax.imshow(confusion_matrix, cmap=cmap)
     plt.colorbar(mappable=im, ax=ax)
@@ -199,7 +201,7 @@ def plot_confusion_matrix(confusion_matrix,
     plt.setp(
         ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor')
 
-    # draw text
+    # draw confusion matrix value
     for i in range(num_classes):
         for j in range(num_classes):
             ax.text(
@@ -226,7 +230,14 @@ def main():
     if args.cfg_options is not None:
         cfg.merge_from_dict(args.cfg_options)
 
-    outputs = mmcv.load(args.prediction_path)
+    results = mmcv.load(args.prediction_path)
+    assert isinstance(results, list)
+    if isinstance(results[0], list):
+        pass
+    elif isinstance(results[0], tuple):
+        results = [result[0] for result in results]
+    else:
+        raise TypeError('invalid type of prediction results')
 
     if isinstance(cfg.data.test, dict):
         cfg.data.test.test_mode = True
@@ -237,10 +246,12 @@ def main():
             ds_cfg.test_mode = True
     dataset = build_dataset(cfg.data.test)
 
-    normalized_confusion_matrix = calculate_confusion_matrix(
-        dataset, outputs, args.score_thr, args.nms_iou_thr, args.tp_iou_thr)
+    confusion_matrix = calculate_confusion_matrix(dataset, results,
+                                                  args.score_thr,
+                                                  args.nms_iou_thr,
+                                                  args.tp_iou_thr)
     plot_confusion_matrix(
-        normalized_confusion_matrix,
+        confusion_matrix,
         dataset.CLASSES + ('background', ),
         save_dir=args.save_dir,
         show=args.show)

From 71483cf2112fbbdf7f77293a8fe2e403105527e7 Mon Sep 17 00:00:00 2001
From: RangiLyu
Date: Fri, 5 Nov 2021 20:57:58 +0800
Subject: [PATCH 6/6] add doc

---
 docs/useful_tools.md                     | 18 ++++++++++++++++++
 tools/analysis_tools/confusion_matrix.py |  4 ++--
 2 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/docs/useful_tools.md b/docs/useful_tools.md
index 474c48b2b9f..838c34c0b11 100644
--- a/docs/useful_tools.md
+++ b/docs/useful_tools.md
@@ -472,3 +472,21 @@ differential_evolution step 489: f(x)= 0.386625
 2021-07-19 19:46:40,776 - mmdet - INFO Anchor differential evolution result:[[10, 12], [15, 30], [32, 22], [29, 59], [61, 46], [57, 116], [112, 89], [154, 198], [349, 336]]
 2021-07-19 19:46:40,798 - mmdet - INFO Result saved in work_dirs/anchor_optimize_result.json
 ```
+
+## Confusion Matrix
+
+A confusion matrix is a summary of prediction results.
+
+`tools/analysis_tools/confusion_matrix.py` can analyze the prediction results and plot a confusion matrix.
+
+First, run `tools/test.py` to save the `.pkl` detection results.
+
+Then, run
+
+```
+python tools/analysis_tools/confusion_matrix.py ${CONFIG} ${DETECTION_RESULTS} ${SAVE_DIR} --show
+```
+
+And you will get a confusion matrix like this:
+
+![confusion_matrix_example](https://user-images.githubusercontent.com/12907710/140513068-994cdbf4-3a4a-48f0-8fd8-2830d93fd963.png)
diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py
index f6af1750712..71e4eb0d9fe 100644
--- a/tools/analysis_tools/confusion_matrix.py
+++ b/tools/analysis_tools/confusion_matrix.py
@@ -173,7 +173,7 @@ def plot_confusion_matrix(confusion_matrix,
 
     title_font = {'weight': 'bold', 'size': 12}
     ax.set_title(title, fontdict=title_font)
-    label_font = {'size': 7}
+    label_font = {'size': 10}
     plt.ylabel('Ground Truth Label', fontdict=label_font)
     plt.xlabel('Prediction Label', fontdict=label_font)
 
@@ -207,7 +207,7 @@ def plot_confusion_matrix(confusion_matrix,
             ax.text(
                 j,
                 i,
-                '%0.1f' % confusion_matrix[i, j],
+                '{}%'.format(int(confusion_matrix[i, j])),
                 ha='center',
                 va='center',
                 color='w',
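
As a usage sketch, the optional thresholds added in this series can be combined in a single call like the one below (the config path, result file, output directory, and threshold values are illustrative only):

```
python tools/analysis_tools/confusion_matrix.py \
    configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py \
    results.pkl \
    work_dirs/confusion_matrix \
    --score-thr 0.3 \
    --tp-iou-thr 0.5 \
    --nms-iou-thr 0.6 \
    --show
```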