From ddbfbfcddc5b56de0a1bf960c9acaaf8fc2fc461 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Tue, 9 Nov 2021 17:16:59 +0800 Subject: [PATCH 1/8] first version --- .../api_wrappers/panoptic_evaluation.py | 156 ++++++++++++++++++ mmdet/datasets/custom.py | 24 ++- mmdet/datasets/pipelines/loading.py | 10 +- 3 files changed, 173 insertions(+), 17 deletions(-) create mode 100644 mmdet/datasets/api_wrappers/panoptic_evaluation.py diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py new file mode 100644 index 00000000000..2b380a5bdb7 --- /dev/null +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -0,0 +1,156 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# This file supports `file_client` for panopticapi +import multiprocessing +import os + +import mmcv +import numpy as np + +try: + from panopticapi.evaluation import PQStat, VOID, OFFSET + from panopticapi.utils import rgb2id +except ImportError: + PQStat = None + rgb2id = None + VOID = 0 + OFFSET = 256 * 256 * 256 + + +def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, + categories, file_client): + if PQStat is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + pq_stat = PQStat() + + idx = 0 + for gt_ann, pred_ann in annotation_set: + if idx % 100 == 0: + print('Core: {}, {} from {} images processed'.format( + proc_id, idx, len(annotation_set))) + idx += 1 + img_bytes = file_client.get( + os.path.join(gt_folder, gt_ann['file_name'])) + pan_gt = mmcv.imfrombytes( + img_bytes, flag='color', channel_order='rgb').squeeze() + pan_gt = rgb2id(pan_gt) + + pan_pred = mmcv.imread( + os.path.join(pred_folder, pred_ann['file_name']), + flag='color', + channel_order='rgb').squeeze() + pan_pred = rgb2id(pan_pred) + + gt_segms = {el['id']: el for el in gt_ann['segments_info']} + pred_segms = {el['id']: el for el in pred_ann['segments_info']} + + # predicted segments area calculation + prediction sanity checks + pred_labels_set = set(el['id'] for el in pred_ann['segments_info']) + labels, labels_cnt = np.unique(pan_pred, return_counts=True) + for label, label_cnt in zip(labels, labels_cnt): + if label not in pred_segms: + if label == VOID: + continue + raise KeyError( + 'In the image with ID {} segment with ID {} is ' + 'presented in PNG and not presented in JSON.'.format( + gt_ann['image_id'], label)) + pred_segms[label]['area'] = label_cnt + pred_labels_set.remove(label) + if pred_segms[label]['category_id'] not in categories: + raise KeyError( + 'In the image with ID {} segment with ID {} has ' + 'unknown category_id {}.'.format( + gt_ann['image_id'], label, + pred_segms[label]['category_id'])) + if len(pred_labels_set) != 0: + raise KeyError( + 'In the image with ID {} the following segment IDs {} ' + 'are presented in JSON and not presented in PNG.'.format( + gt_ann['image_id'], list(pred_labels_set))) + + # confusion matrix calculation + pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype( + np.uint64) + gt_pred_map = {} + labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True) + for label, intersection in zip(labels, labels_cnt): + gt_id = label // OFFSET + pred_id = label % OFFSET + gt_pred_map[(gt_id, pred_id)] = intersection + + # count all matched pairs + gt_matched = set() + pred_matched = set() + for label_tuple, intersection in gt_pred_map.items(): + gt_label, pred_label = label_tuple + if gt_label not in gt_segms: + continue + if pred_label not in pred_segms: + continue + if gt_segms[gt_label]['iscrowd'] == 1: + continue + if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][ + 'category_id']: + continue + + union = pred_segms[pred_label]['area'] + gt_segms[gt_label][ + 'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0) + iou = intersection / union + if iou > 0.5: + pq_stat[gt_segms[gt_label]['category_id']].tp += 1 + pq_stat[gt_segms[gt_label]['category_id']].iou += iou + gt_matched.add(gt_label) + pred_matched.add(pred_label) + + # count false positives + crowd_labels_dict = {} + for gt_label, gt_info in gt_segms.items(): + if gt_label in gt_matched: + continue + # crowd segments are ignored + if gt_info['iscrowd'] == 1: + crowd_labels_dict[gt_info['category_id']] = gt_label + continue + pq_stat[gt_info['category_id']].fn += 1 + + # count false positives + for pred_label, pred_info in pred_segms.items(): + if pred_label in pred_matched: + continue + # intersection of the segment with VOID + intersection = gt_pred_map.get((VOID, pred_label), 0) + # plus intersection with corresponding CROWD region if it exists + if pred_info['category_id'] in crowd_labels_dict: + intersection += gt_pred_map.get( + (crowd_labels_dict[pred_info['category_id']], pred_label), + 0) + # predicted segment is ignored if more than half of + # the segment correspond to VOID and CROWD regions + if intersection / pred_info['area'] > 0.5: + continue + pq_stat[pred_info['category_id']].fp += 1 + print('Core: {}, all {} images processed'.format(proc_id, + len(annotation_set))) + return pq_stat + + +def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, + categories, file_client): + cpu_num = multiprocessing.cpu_count() + annotations_split = np.array_split(matched_annotations_list, cpu_num) + print('Number of cores: {}, images per core: {}'.format( + cpu_num, len(annotations_split[0]))) + workers = multiprocessing.Pool(processes=cpu_num) + processes = [] + for proc_id, annotation_set in enumerate(annotations_split): + p = workers.apply_async(pq_compute_single_core, + (proc_id, annotation_set, gt_folder, + pred_folder, categories, file_client)) + processes.append(p) + pq_stat = PQStat() + for p in processes: + pq_stat += p.get() + return pq_stat diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index 0e82fa396ad..eea1453364d 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -63,7 +63,8 @@ def __init__(self, seg_prefix=None, proposal_file=None, test_mode=False, - filter_empty_gt=True): + filter_empty_gt=True, + file_client_args=dict(backend='disk')): self.ann_file = ann_file self.data_root = data_root self.img_prefix = img_prefix @@ -72,24 +73,31 @@ def __init__(self, self.test_mode = test_mode self.filter_empty_gt = filter_empty_gt self.CLASSES = self.get_classes(classes) + self.file_client = mmcv.FileClient(**file_client_args) # join paths if data_root is specified if self.data_root is not None: if not osp.isabs(self.ann_file): - self.ann_file = osp.join(self.data_root, self.ann_file) + self.ann_file = self.file_client.join_path( + self.data_root, self.ann_file) if not (self.img_prefix is None or osp.isabs(self.img_prefix)): - self.img_prefix = osp.join(self.data_root, self.img_prefix) + self.img_prefix = self.file_client.join_path( + self.data_root, self.img_prefix) if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): - self.seg_prefix = osp.join(self.data_root, self.seg_prefix) + self.seg_prefix = self.file_client.join_path( + self.data_root, self.seg_prefix) if not (self.proposal_file is None or osp.isabs(self.proposal_file)): - self.proposal_file = osp.join(self.data_root, - self.proposal_file) + self.proposal_file = self.file_client.join_path( + self.data_root, self.proposal_file) # load annotations (and proposals) - self.data_infos = self.load_annotations(self.ann_file) + with self.file_client.get_local_path(self.ann_file) as local_path: + self.data_infos = self.load_annotations(local_path) if self.proposal_file is not None: - self.proposals = self.load_proposals(self.proposal_file) + with self.file_client.get_local_path( + self.proposal_file) as local_path: + self.proposals = self.load_proposals(local_path) else: self.proposals = None diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index 1f8ba37e1a5..a7f1974a025 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -231,7 +231,7 @@ def __init__(self, self.with_seg = with_seg self.poly2mask = poly2mask self.file_client_args = file_client_args.copy() - self.file_client = None + self.file_client = mmcv.FileClient(**self.file_client_args) def _load_bboxes(self, results): """Private function to load bounding box annotations. @@ -344,10 +344,6 @@ def _load_semantic_seg(self, results): Returns: dict: The dict contains loaded semantic segmentation annotations. """ - - if self.file_client is None: - self.file_client = mmcv.FileClient(**self.file_client_args) - filename = osp.join(results['seg_prefix'], results['ann_info']['seg_map']) img_bytes = self.file_client.get(filename) @@ -438,12 +434,8 @@ def _load_masks_and_semantic_segs(self, results): dict: The dict contains loaded mask and semantic segmentation annotations. `BitmapMasks` is used for mask annotations. """ - - if self.file_client is None: - self.file_client = mmcv.FileClient(**self.file_client_args) filename = osp.join(results['seg_prefix'], results['ann_info']['seg_map']) - img_bytes = self.file_client.get(filename) pan_png = mmcv.imfrombytes( img_bytes, flag='color', channel_order='rgb').squeeze() From d5c834919b99d21df13533db93979cc02514329a Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Thu, 11 Nov 2021 20:28:50 +0800 Subject: [PATCH 2/8] Replace with our api --- mmdet/datasets/api_wrappers/__init__.py | 3 ++- mmdet/datasets/coco_panoptic.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mmdet/datasets/api_wrappers/__init__.py b/mmdet/datasets/api_wrappers/__init__.py index 9bf807107a4..6dc535a113c 100644 --- a/mmdet/datasets/api_wrappers/__init__.py +++ b/mmdet/datasets/api_wrappers/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .coco_api import COCO, COCOeval +from .panoptic_evaluation import pq_compute_multi_core -__all__ = ['COCO', 'COCOeval'] +__all__ = ['COCO', 'COCOeval', 'pq_compute_multi_core'] diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py index 36e95950d97..e2b5c636409 100644 --- a/mmdet/datasets/coco_panoptic.py +++ b/mmdet/datasets/coco_panoptic.py @@ -8,17 +8,16 @@ from mmcv.utils import print_log from terminaltables import AsciiTable -from .api_wrappers import COCO +from .api_wrappers import COCO, pq_compute_multi_core from .builder import DATASETS from .coco import CocoDataset try: import panopticapi - from panopticapi.evaluation import pq_compute_multi_core, VOID + from panopticapi.evaluation import VOID from panopticapi.utils import id2rgb except ImportError: panopticapi = None - pq_compute_multi_core = None id2rgb = None VOID = None @@ -421,7 +420,8 @@ def evaluate_pan_json(self, pred_folder = os.path.join(os.path.dirname(outfile_prefix), 'panoptic') pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder, - pred_folder, self.categories) + pred_folder, self.categories, + self.file_client) metrics = [('All', None), ('Things', True), ('Stuff', False)] pq_results = {} From cae49e01e050b38a7ed7f4179748c762ff1dccf9 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Fri, 12 Nov 2021 14:59:32 +0800 Subject: [PATCH 3/8] Add copyright --- mmdet/datasets/api_wrappers/panoptic_evaluation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index 2b380a5bdb7..8f83cccdc01 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -1,5 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. -# This file supports `file_client` for panopticapi + +# Copyright (c) 2018, Alexander Kirillov +# This file supports `file_client` for `panopticapi`, +# the source code is copied from `panopticapi`, +# only the way to read the gt images is modified. import multiprocessing import os From fb96094733836d5b5cc2018ea34e2a3207a68fe9 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Fri, 12 Nov 2021 15:24:27 +0800 Subject: [PATCH 4/8] Move the runtime error to multi_core interface --- mmdet/datasets/api_wrappers/panoptic_evaluation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index 8f83cccdc01..a83af348595 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -22,11 +22,6 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories, file_client): - if PQStat is None: - raise RuntimeError( - 'panopticapi is not installed, please install it by: ' - 'pip install git+https://github.com/cocodataset/' - 'panopticapi.git.') pq_stat = PQStat() idx = 0 @@ -143,6 +138,11 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories, file_client): + if PQStat is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') cpu_num = multiprocessing.cpu_count() annotations_split = np.array_split(matched_annotations_list, cpu_num) print('Number of cores: {}, images per core: {}'.format( From 9197efa381be2ee2cd0fab92fe535b84aa37dfd6 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Fri, 12 Nov 2021 16:27:37 +0800 Subject: [PATCH 5/8] Add docstring --- .../api_wrappers/panoptic_evaluation.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index a83af348595..427e94b80f9 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -3,7 +3,7 @@ # Copyright (c) 2018, Alexander Kirillov # This file supports `file_client` for `panopticapi`, # the source code is copied from `panopticapi`, -# only the way to read the gt images is modified. +# only the way to load the gt images is modified. import multiprocessing import os @@ -22,6 +22,19 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, categories, file_client): + """The single core function to evaluate the metric of Panoptic + Segmentation. + + Same as the function with the same name in `panopticapi`. Only the function + to load the images is changed to use the file client. + + Args: + proc_id (int): The id of the mini process. + gt_folder (str): The path of the ground truth images. + pred_folder (str): The path of the prediction images. + categories (str): The categories of the dataset. + file_client (object): The file client of the dataset. + """ pq_stat = PQStat() idx = 0 @@ -138,6 +151,19 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, categories, file_client): + """Evaluate the metrics of Panoptic Segmentation with multithreading. + + Same as the function with the same name in `panopticapi`. + + Args: + matched_annotations_list (list): The matched annotation list. Each + element is a tuple of annotations of the same image with the + format (gt_anns, pred_anns). + gt_folder (str): The path of the ground truth images. + pred_folder (str): The path of the prediction images. + categories (str): The categories of the dataset. + file_client (object): The file client of the dataset. + """ if PQStat is None: raise RuntimeError( 'panopticapi is not installed, please install it by: ' From 18e8d47b1ba280c98fca851e8ce66e2c62a7d5cc Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Mon, 15 Nov 2021 16:04:52 +0800 Subject: [PATCH 6/8] Fix comments --- .../api_wrappers/panoptic_evaluation.py | 35 ++++++++++++++----- mmdet/datasets/custom.py | 13 +++---- mmdet/datasets/pipelines/loading.py | 10 +++++- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index 427e94b80f9..86e84c858c2 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -20,8 +20,12 @@ OFFSET = 256 * 256 * 256 -def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, - categories, file_client): +def pq_compute_single_core(proc_id, + annotation_set, + gt_folder, + pred_folder, + categories, + file_client=None): """The single core function to evaluate the metric of Panoptic Segmentation. @@ -33,8 +37,13 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, gt_folder (str): The path of the ground truth images. pred_folder (str): The path of the prediction images. categories (str): The categories of the dataset. - file_client (object): The file client of the dataset. + file_client (object): The file client of the dataset. If None, + the backend will be set to `disk`. """ + if file_client is None: + file_client_args = dict(backend='disk') + file_client = mmcv.FileClient(**file_client_args) + pq_stat = PQStat() idx = 0 @@ -45,14 +54,13 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, idx += 1 img_bytes = file_client.get( os.path.join(gt_folder, gt_ann['file_name'])) - pan_gt = mmcv.imfrombytes( - img_bytes, flag='color', channel_order='rgb').squeeze() + pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb') pan_gt = rgb2id(pan_gt) pan_pred = mmcv.imread( os.path.join(pred_folder, pred_ann['file_name']), flag='color', - channel_order='rgb').squeeze() + channel_order='rgb') pan_pred = rgb2id(pan_pred) gt_segms = {el['id']: el for el in gt_ann['segments_info']} @@ -149,8 +157,11 @@ def pq_compute_single_core(proc_id, annotation_set, gt_folder, pred_folder, return pq_stat -def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, - categories, file_client): +def pq_compute_multi_core(matched_annotations_list, + gt_folder, + pred_folder, + categories, + file_client=None): """Evaluate the metrics of Panoptic Segmentation with multithreading. Same as the function with the same name in `panopticapi`. @@ -162,13 +173,19 @@ def pq_compute_multi_core(matched_annotations_list, gt_folder, pred_folder, gt_folder (str): The path of the ground truth images. pred_folder (str): The path of the prediction images. categories (str): The categories of the dataset. - file_client (object): The file client of the dataset. + file_client (object): The file client of the dataset. If None, + the backend will be set to `disk`. """ if PQStat is None: raise RuntimeError( 'panopticapi is not installed, please install it by: ' 'pip install git+https://github.com/cocodataset/' 'panopticapi.git.') + + if file_client is None: + file_client_args = dict(backend='disk') + file_client = mmcv.FileClient(**file_client_args) + cpu_num = multiprocessing.cpu_count() annotations_split = np.array_split(matched_annotations_list, cpu_num) print('Number of cores: {}, images per core: {}'.format( diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py index eea1453364d..b730ecc2b9e 100644 --- a/mmdet/datasets/custom.py +++ b/mmdet/datasets/custom.py @@ -78,18 +78,15 @@ def __init__(self, # join paths if data_root is specified if self.data_root is not None: if not osp.isabs(self.ann_file): - self.ann_file = self.file_client.join_path( - self.data_root, self.ann_file) + self.ann_file = osp.join(self.data_root, self.ann_file) if not (self.img_prefix is None or osp.isabs(self.img_prefix)): - self.img_prefix = self.file_client.join_path( - self.data_root, self.img_prefix) + self.img_prefix = osp.join(self.data_root, self.img_prefix) if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)): - self.seg_prefix = self.file_client.join_path( - self.data_root, self.seg_prefix) + self.seg_prefix = osp.join(self.data_root, self.seg_prefix) if not (self.proposal_file is None or osp.isabs(self.proposal_file)): - self.proposal_file = self.file_client.join_path( - self.data_root, self.proposal_file) + self.proposal_file = osp.join(self.data_root, + self.proposal_file) # load annotations (and proposals) with self.file_client.get_local_path(self.ann_file) as local_path: self.data_infos = self.load_annotations(local_path) diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py index a7f1974a025..4a6e043e8e1 100644 --- a/mmdet/datasets/pipelines/loading.py +++ b/mmdet/datasets/pipelines/loading.py @@ -231,7 +231,7 @@ def __init__(self, self.with_seg = with_seg self.poly2mask = poly2mask self.file_client_args = file_client_args.copy() - self.file_client = mmcv.FileClient(**self.file_client_args) + self.file_client = None def _load_bboxes(self, results): """Private function to load bounding box annotations. @@ -344,6 +344,10 @@ def _load_semantic_seg(self, results): Returns: dict: The dict contains loaded semantic segmentation annotations. """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + filename = osp.join(results['seg_prefix'], results['ann_info']['seg_map']) img_bytes = self.file_client.get(filename) @@ -434,6 +438,10 @@ def _load_masks_and_semantic_segs(self, results): dict: The dict contains loaded mask and semantic segmentation annotations. `BitmapMasks` is used for mask annotations. """ + + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + filename = osp.join(results['seg_prefix'], results['ann_info']['seg_map']) img_bytes = self.file_client.get(filename) From 1ece2e87f978be5e838965ce008b1666b539e9a6 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Fri, 19 Nov 2021 15:32:01 +0800 Subject: [PATCH 7/8] Add comments --- mmdet/datasets/api_wrappers/panoptic_evaluation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index 86e84c858c2..b595aaf8f36 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -52,11 +52,14 @@ def pq_compute_single_core(proc_id, print('Core: {}, {} from {} images processed'.format( proc_id, idx, len(annotation_set))) idx += 1 + # The gt images can be on the local disk or `ceph`, so we use + # file_client here. img_bytes = file_client.get( os.path.join(gt_folder, gt_ann['file_name'])) pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb') pan_gt = rgb2id(pan_gt) + # The predictions can only be on the local dist now. pan_pred = mmcv.imread( os.path.join(pred_folder, pred_ann['file_name']), flag='color', From 8f9df510e087c1174d3768f02d9a824f842fae59 Mon Sep 17 00:00:00 2001 From: AronLin <347630870@qq.com> Date: Fri, 19 Nov 2021 20:28:41 +0800 Subject: [PATCH 8/8] Add unit test for pq_compute_single_core --- mmdet/datasets/api_wrappers/__init__.py | 6 ++-- .../api_wrappers/panoptic_evaluation.py | 6 ++++ .../test_datasets/test_panoptic_dataset.py | 32 +++++++++++++++++++ 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/mmdet/datasets/api_wrappers/__init__.py b/mmdet/datasets/api_wrappers/__init__.py index 6dc535a113c..af8557593b6 100644 --- a/mmdet/datasets/api_wrappers/__init__.py +++ b/mmdet/datasets/api_wrappers/__init__.py @@ -1,5 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .coco_api import COCO, COCOeval -from .panoptic_evaluation import pq_compute_multi_core +from .panoptic_evaluation import pq_compute_multi_core, pq_compute_single_core -__all__ = ['COCO', 'COCOeval', 'pq_compute_multi_core'] +__all__ = [ + 'COCO', 'COCOeval', 'pq_compute_multi_core', 'pq_compute_single_core' +] diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py index b595aaf8f36..1a21fe8f098 100644 --- a/mmdet/datasets/api_wrappers/panoptic_evaluation.py +++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py @@ -40,6 +40,12 @@ def pq_compute_single_core(proc_id, file_client (object): The file client of the dataset. If None, the backend will be set to `disk`. """ + if PQStat is None: + raise RuntimeError( + 'panopticapi is not installed, please install it by: ' + 'pip install git+https://github.com/cocodataset/' + 'panopticapi.git.') + if file_client is None: file_client_args = dict(backend='disk') file_client = mmcv.FileClient(**file_client_args) diff --git a/tests/test_data/test_datasets/test_panoptic_dataset.py b/tests/test_data/test_datasets/test_panoptic_dataset.py index 44670f22f07..fd571d219d1 100644 --- a/tests/test_data/test_datasets/test_panoptic_dataset.py +++ b/tests/test_data/test_datasets/test_panoptic_dataset.py @@ -5,6 +5,7 @@ import mmcv import numpy as np +from mmdet.datasets.api_wrappers import pq_compute_single_core from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET, CocoPanopticDataset try: @@ -305,3 +306,34 @@ def test_panoptic_evaluation(): assert np.isclose(parsed_results['PQ_st'], 82.701) assert np.isclose(parsed_results['SQ_st'], 82.701) assert np.isclose(parsed_results['RQ_st'], 100.000) + + # test the api wrapper of `pq_compute_single_core` + # Codes are copied from `coco_panoptic.py` and modified + result_files, _ = dataset.format_results( + results, jsonfile_prefix=outfile_prefix) + + imgs = dataset.coco.imgs + gt_json = dataset.coco.img_ann_map # image to annotations + gt_json = [{ + 'image_id': k, + 'segments_info': v, + 'file_name': imgs[k]['segm_file'] + } for k, v in gt_json.items()] + pred_json = mmcv.load(result_files['panoptic']) + pred_json = dict((el['image_id'], el) for el in pred_json['annotations']) + + # match the gt_anns and pred_anns in the same image + matched_annotations_list = [] + for gt_ann in gt_json: + img_id = gt_ann['image_id'] + matched_annotations_list.append((gt_ann, pred_json[img_id])) + gt_folder = dataset.seg_prefix + pred_folder = osp.join(osp.dirname(outfile_prefix), 'panoptic') + + pq_stat = pq_compute_single_core(0, matched_annotations_list, gt_folder, + pred_folder, dataset.categories) + pq_all = pq_stat.pq_average(dataset.categories, isthing=None)[0] + assert np.isclose(pq_all['pq'] * 100, 67.869) + assert np.isclose(pq_all['sq'] * 100, 80.898) + assert np.isclose(pq_all['rq'] * 100, 83.333) + assert pq_all['n'] == 3