diff --git a/mmdet/datasets/api_wrappers/__init__.py b/mmdet/datasets/api_wrappers/__init__.py
index 9bf807107a4..af8557593b6 100644
--- a/mmdet/datasets/api_wrappers/__init__.py
+++ b/mmdet/datasets/api_wrappers/__init__.py
@@ -1,4 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .coco_api import COCO, COCOeval
+from .panoptic_evaluation import pq_compute_multi_core, pq_compute_single_core
 
-__all__ = ['COCO', 'COCOeval']
+__all__ = [
+    'COCO', 'COCOeval', 'pq_compute_multi_core', 'pq_compute_single_core'
+]
diff --git a/mmdet/datasets/api_wrappers/panoptic_evaluation.py b/mmdet/datasets/api_wrappers/panoptic_evaluation.py
new file mode 100644
index 00000000000..1a21fe8f098
--- /dev/null
+++ b/mmdet/datasets/api_wrappers/panoptic_evaluation.py
@@ -0,0 +1,241 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+# Copyright (c) 2018, Alexander Kirillov
+# This file supports `file_client` for `panopticapi`,
+# the source code is copied from `panopticapi`,
+# only the way to load the gt images is modified.
+import multiprocessing
+import os
+
+import mmcv
+import numpy as np
+
+try:
+    from panopticapi.evaluation import PQStat, VOID, OFFSET
+    from panopticapi.utils import rgb2id
+except ImportError:
+    PQStat = None
+    rgb2id = None
+    VOID = 0
+    OFFSET = 256 * 256 * 256
+
+
+def pq_compute_single_core(proc_id,
+                           annotation_set,
+                           gt_folder,
+                           pred_folder,
+                           categories,
+                           file_client=None):
+    """The single core function to evaluate the metric of Panoptic
+    Segmentation.
+
+    Same as the function with the same name in `panopticapi`. Only the function
+    to load the images is changed to use the file client.
+
+    Args:
+        proc_id (int): The id of the mini process.
+        annotation_set (list | np.ndarray): The matched annotations
+            handled by this process. Each element is a pair of annotations
+            of the same image with the format (gt_anns, pred_anns).
+        gt_folder (str): The path of the ground truth images.
+        pred_folder (str): The path of the prediction images.
+        categories (dict): The categories of the dataset, keyed by
+            category id.
+        file_client (object): The file client of the dataset. If None,
+            the backend will be set to `disk`.
+
+    Returns:
+        PQStat: The statistics accumulated over `annotation_set`.
+
+    Raises:
+        RuntimeError: If `panopticapi` is not installed.
+        KeyError: If the segment ids in a prediction PNG and in its JSON
+            annotation do not match, or a category id is unknown.
+    """
+    if PQStat is None:
+        raise RuntimeError(
+            'panopticapi is not installed, please install it by: '
+            'pip install git+https://github.com/cocodataset/'
+            'panopticapi.git.')
+
+    if file_client is None:
+        file_client_args = dict(backend='disk')
+        file_client = mmcv.FileClient(**file_client_args)
+
+    pq_stat = PQStat()
+
+    idx = 0
+    for gt_ann, pred_ann in annotation_set:
+        if idx % 100 == 0:
+            print('Core: {}, {} from {} images processed'.format(
+                proc_id, idx, len(annotation_set)))
+        idx += 1
+        # The gt images can be on the local disk or `ceph`, so we use
+        # file_client here.
+        img_bytes = file_client.get(
+            os.path.join(gt_folder, gt_ann['file_name']))
+        pan_gt = mmcv.imfrombytes(img_bytes, flag='color', channel_order='rgb')
+        pan_gt = rgb2id(pan_gt)
+
+        # The predictions can only be on the local disk now.
+        pan_pred = mmcv.imread(
+            os.path.join(pred_folder, pred_ann['file_name']),
+            flag='color',
+            channel_order='rgb')
+        pan_pred = rgb2id(pan_pred)
+
+        gt_segms = {el['id']: el for el in gt_ann['segments_info']}
+        pred_segms = {el['id']: el for el in pred_ann['segments_info']}
+
+        # predicted segments area calculation + prediction sanity checks
+        pred_labels_set = set(el['id'] for el in pred_ann['segments_info'])
+        labels, labels_cnt = np.unique(pan_pred, return_counts=True)
+        for label, label_cnt in zip(labels, labels_cnt):
+            if label not in pred_segms:
+                if label == VOID:
+                    continue
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} is '
+                    'presented in PNG and not presented in JSON.'.format(
+                        gt_ann['image_id'], label))
+            pred_segms[label]['area'] = label_cnt
+            pred_labels_set.remove(label)
+            if pred_segms[label]['category_id'] not in categories:
+                raise KeyError(
+                    'In the image with ID {} segment with ID {} has '
+                    'unknown category_id {}.'.format(
+                        gt_ann['image_id'], label,
+                        pred_segms[label]['category_id']))
+        if len(pred_labels_set) != 0:
+            raise KeyError(
+                'In the image with ID {} the following segment IDs {} '
+                'are presented in JSON and not presented in PNG.'.format(
+                    gt_ann['image_id'], list(pred_labels_set)))
+
+        # confusion matrix calculation: each pixel pair is encoded as
+        # gt_id * OFFSET + pred_id, so a single `np.unique` call counts
+        # the intersection of every (gt segment, pred segment) pair
+        pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(
+            np.uint64)
+        gt_pred_map = {}
+        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
+        for label, intersection in zip(labels, labels_cnt):
+            gt_id = label // OFFSET
+            pred_id = label % OFFSET
+            gt_pred_map[(gt_id, pred_id)] = intersection
+
+        # count all matched pairs
+        gt_matched = set()
+        pred_matched = set()
+        for label_tuple, intersection in gt_pred_map.items():
+            gt_label, pred_label = label_tuple
+            if gt_label not in gt_segms:
+                continue
+            if pred_label not in pred_segms:
+                continue
+            if gt_segms[gt_label]['iscrowd'] == 1:
+                continue
+            if gt_segms[gt_label]['category_id'] != pred_segms[pred_label][
+                    'category_id']:
+                continue
+
+            # pixels of the prediction that fall into VOID are excluded
+            # from the union
+            union = pred_segms[pred_label]['area'] + gt_segms[gt_label][
+                'area'] - intersection - gt_pred_map.get((VOID, pred_label), 0)
+            iou = intersection / union
+            if iou > 0.5:
+                pq_stat[gt_segms[gt_label]['category_id']].tp += 1
+                pq_stat[gt_segms[gt_label]['category_id']].iou += iou
+                gt_matched.add(gt_label)
+                pred_matched.add(pred_label)
+
+        # count false negatives
+        crowd_labels_dict = {}
+        for gt_label, gt_info in gt_segms.items():
+            if gt_label in gt_matched:
+                continue
+            # crowd segments are ignored
+            if gt_info['iscrowd'] == 1:
+                crowd_labels_dict[gt_info['category_id']] = gt_label
+                continue
+            pq_stat[gt_info['category_id']].fn += 1
+
+        # count false positives
+        for pred_label, pred_info in pred_segms.items():
+            if pred_label in pred_matched:
+                continue
+            # intersection of the segment with VOID
+            intersection = gt_pred_map.get((VOID, pred_label), 0)
+            # plus intersection with corresponding CROWD region if it exists
+            if pred_info['category_id'] in crowd_labels_dict:
+                intersection += gt_pred_map.get(
+                    (crowd_labels_dict[pred_info['category_id']], pred_label),
+                    0)
+            # predicted segment is ignored if more than half of
+            # the segment correspond to VOID and CROWD regions
+            if intersection / pred_info['area'] > 0.5:
+                continue
+            pq_stat[pred_info['category_id']].fp += 1
+    print('Core: {}, all {} images processed'.format(proc_id,
+                                                     len(annotation_set)))
+    return pq_stat
+
+
+def pq_compute_multi_core(matched_annotations_list,
+                          gt_folder,
+                          pred_folder,
+                          categories,
+                          file_client=None):
+    """Evaluate the metrics of Panoptic Segmentation with multiprocessing.
+
+    Same as the function with the same name in `panopticapi`.
+
+    Args:
+        matched_annotations_list (list): The matched annotation list. Each
+            element is a tuple of annotations of the same image with the
+            format (gt_anns, pred_anns).
+        gt_folder (str): The path of the ground truth images.
+        pred_folder (str): The path of the prediction images.
+        categories (dict): The categories of the dataset, keyed by
+            category id.
+        file_client (object): The file client of the dataset. If None,
+            the backend will be set to `disk`.
+
+    Returns:
+        PQStat: The statistics merged from all worker processes.
+
+    Raises:
+        RuntimeError: If `panopticapi` is not installed.
+    """
+    if PQStat is None:
+        raise RuntimeError(
+            'panopticapi is not installed, please install it by: '
+            'pip install git+https://github.com/cocodataset/'
+            'panopticapi.git.')
+
+    if file_client is None:
+        file_client_args = dict(backend='disk')
+        file_client = mmcv.FileClient(**file_client_args)
+
+    cpu_num = multiprocessing.cpu_count()
+    annotations_split = np.array_split(matched_annotations_list, cpu_num)
+    print('Number of cores: {}, images per core: {}'.format(
+        cpu_num, len(annotations_split[0])))
+    workers = multiprocessing.Pool(processes=cpu_num)
+    try:
+        processes = []
+        for proc_id, annotation_set in enumerate(annotations_split):
+            p = workers.apply_async(pq_compute_single_core,
+                                    (proc_id, annotation_set, gt_folder,
+                                     pred_folder, categories, file_client))
+            processes.append(p)
+        pq_stat = PQStat()
+        for p in processes:
+            pq_stat += p.get()
+    finally:
+        # Always shut the pool down, even if a worker raised; otherwise
+        # the worker processes are leaked.
+        workers.close()
+        workers.join()
+    return pq_stat
diff --git a/mmdet/datasets/coco_panoptic.py b/mmdet/datasets/coco_panoptic.py
index 36e95950d97..e2b5c636409 100644
--- a/mmdet/datasets/coco_panoptic.py
+++ b/mmdet/datasets/coco_panoptic.py
@@ -8,17 +8,16 @@
 from mmcv.utils import print_log
 from terminaltables import AsciiTable
 
-from .api_wrappers import COCO
+from .api_wrappers import COCO, pq_compute_multi_core
 from .builder import DATASETS
 from .coco import CocoDataset
 
 try:
     import panopticapi
-    from panopticapi.evaluation import pq_compute_multi_core, VOID
+    from panopticapi.evaluation import VOID
     from panopticapi.utils import id2rgb
 except ImportError:
     panopticapi = None
-    pq_compute_multi_core = None
     id2rgb = None
     VOID = None
 
@@ -421,7 +420,8 @@ def evaluate_pan_json(self,
         pred_folder = os.path.join(os.path.dirname(outfile_prefix), 'panoptic')
 
         pq_stat = pq_compute_multi_core(matched_annotations_list, gt_folder,
-                                        pred_folder, self.categories)
+                                        pred_folder, self.categories,
+                                        self.file_client)
 
         metrics = [('All', None), ('Things', True), ('Stuff', False)]
         pq_results = {}
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
index 0e82fa396ad..b730ecc2b9e 100644
--- a/mmdet/datasets/custom.py
+++ b/mmdet/datasets/custom.py
@@ -63,7 +63,8 @@ def __init__(self,
                  seg_prefix=None,
                  proposal_file=None,
                  test_mode=False,
-                 filter_empty_gt=True):
+                 filter_empty_gt=True,
+                 file_client_args=dict(backend='disk')):
         self.ann_file = ann_file
         self.data_root = data_root
         self.img_prefix = img_prefix
@@ -72,6 +73,7 @@ def __init__(self,
         self.test_mode = test_mode
         self.filter_empty_gt = filter_empty_gt
         self.CLASSES = self.get_classes(classes)
+        self.file_client = mmcv.FileClient(**file_client_args)
 
         # join paths if data_root is specified
         if self.data_root is not None:
@@ -86,10 +88,13 @@ def __init__(self,
                 self.proposal_file = osp.join(self.data_root,
                                               self.proposal_file)
         # load annotations (and proposals)
-        self.data_infos = self.load_annotations(self.ann_file)
+        with self.file_client.get_local_path(self.ann_file) as local_path:
+            self.data_infos = self.load_annotations(local_path)
 
         if self.proposal_file is not None:
-            self.proposals = self.load_proposals(self.proposal_file)
+            with self.file_client.get_local_path(
+                    self.proposal_file) as local_path:
+                self.proposals = self.load_proposals(local_path)
         else:
             self.proposals = None
 
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
index 1f8ba37e1a5..4a6e043e8e1 100644
--- a/mmdet/datasets/pipelines/loading.py
+++ b/mmdet/datasets/pipelines/loading.py
@@ -441,9 +441,9 @@ def _load_masks_and_semantic_segs(self, results):
 
         if self.file_client is None:
             self.file_client = mmcv.FileClient(**self.file_client_args)
+
         filename = osp.join(results['seg_prefix'],
                             results['ann_info']['seg_map'])
-
         img_bytes = self.file_client.get(filename)
         pan_png = mmcv.imfrombytes(
             img_bytes, flag='color', channel_order='rgb').squeeze()
diff --git a/tests/test_data/test_datasets/test_panoptic_dataset.py b/tests/test_data/test_datasets/test_panoptic_dataset.py
index 44670f22f07..fd571d219d1 100644
--- a/tests/test_data/test_datasets/test_panoptic_dataset.py
+++ b/tests/test_data/test_datasets/test_panoptic_dataset.py
@@ -5,6 +5,7 @@
 import mmcv
 import numpy as np
 
+from mmdet.datasets.api_wrappers import pq_compute_single_core
 from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET, CocoPanopticDataset
 
 try:
@@ -305,3 +306,34 @@ def test_panoptic_evaluation():
     assert np.isclose(parsed_results['PQ_st'], 82.701)
     assert np.isclose(parsed_results['SQ_st'], 82.701)
     assert np.isclose(parsed_results['RQ_st'], 100.000)
+
+    # test the api wrapper of `pq_compute_single_core`
+    # Codes are copied from `coco_panoptic.py` and modified
+    result_files, _ = dataset.format_results(
+        results, jsonfile_prefix=outfile_prefix)
+
+    imgs = dataset.coco.imgs
+    gt_json = dataset.coco.img_ann_map  # image to annotations
+    gt_json = [{
+        'image_id': k,
+        'segments_info': v,
+        'file_name': imgs[k]['segm_file']
+    } for k, v in gt_json.items()]
+    pred_json = mmcv.load(result_files['panoptic'])
+    pred_json = dict((el['image_id'], el) for el in pred_json['annotations'])
+
+    # match the gt_anns and pred_anns in the same image
+    matched_annotations_list = []
+    for gt_ann in gt_json:
+        img_id = gt_ann['image_id']
+        matched_annotations_list.append((gt_ann, pred_json[img_id]))
+    gt_folder = dataset.seg_prefix
+    pred_folder = osp.join(osp.dirname(outfile_prefix), 'panoptic')
+
+    pq_stat = pq_compute_single_core(0, matched_annotations_list, gt_folder,
+                                     pred_folder, dataset.categories)
+    pq_all = pq_stat.pq_average(dataset.categories, isthing=None)[0]
+    assert np.isclose(pq_all['pq'] * 100, 67.869)
+    assert np.isclose(pq_all['sq'] * 100, 80.898)
+    assert np.isclose(pq_all['rq'] * 100, 83.333)
+    assert pq_all['n'] == 3