Skip to content

Commit

Permalink
[Enhance] Add instance evalutation for coco_panoptic (#7313)
Browse files Browse the repository at this point in the history
update comments

rename function and replace condition

rename

add message for proposal_fast when instance segmentation evaluation

set cocoGt as arg

update comments

update comments

update docstring and rename

add unit test

update docstring

add assert for instance eval
  • Loading branch information
chhluo authored Mar 16, 2022
1 parent 46fc0c8 commit ab662f9
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 34 deletions.
14 changes: 14 additions & 0 deletions mmdet/apis/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def single_gpu_test(model,
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
# This logic is only used in panoptic segmentation test.
elif isinstance(result[0], dict) and 'ins_results' in result[0]:
for j in range(len(result)):
bbox_results, mask_results = result[j]['ins_results']
result[j]['ins_results'] = (bbox_results,
encode_mask_results(mask_results))

results.extend(result)

for _ in range(batch_size):
Expand Down Expand Up @@ -104,6 +111,13 @@ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
if isinstance(result[0], tuple):
result = [(bbox_results, encode_mask_results(mask_results))
for bbox_results, mask_results in result]
# This logic is only used in panoptic segmentation test.
elif isinstance(result[0], dict) and 'ins_results' in result[0]:
for j in range(len(result)):
bbox_results, mask_results = result[j]['ins_results']
result[j]['ins_results'] = (
bbox_results, encode_mask_results(mask_results))

results.extend(result)

if rank == 0:
Expand Down
101 changes: 79 additions & 22 deletions mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,19 +383,24 @@ def format_results(self, results, jsonfile_prefix=None, **kwargs):
result_files = self.results2json(results, jsonfile_prefix)
return result_files, tmp_dir

def evaluate(self,
results,
metric='bbox',
logger=None,
jsonfile_prefix=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=None,
metric_items=None):
"""Evaluation in COCO protocol.
def evaluate_det_segm(self,
results,
result_files,
coco_gt,
metrics,
logger=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=None,
metric_items=None):
"""Instance segmentation and object detection evaluation in COCO
protocol.
Args:
results (list[list | tuple]): Testing results of the dataset.
results (list[list | tuple | dict]): Testing results of the
dataset.
result_files (dict[str, str]): a dict contains json file path.
coco_gt (COCO): COCO API object with ground truth annotation.
metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing
Expand All @@ -422,30 +427,24 @@ def evaluate(self,
Returns:
dict[str, float]: COCO style evaluation metric.
"""

metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
for metric in metrics:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
if iou_thrs is None:
iou_thrs = np.linspace(
.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
if metric_items is not None:
if not isinstance(metric_items, list):
metric_items = [metric_items]

result_files, tmp_dir = self.format_results(results, jsonfile_prefix)

eval_results = OrderedDict()
cocoGt = self.coco
for metric in metrics:
msg = f'Evaluating {metric}...'
if logger is None:
msg = '\n' + msg
print_log(msg, logger=logger)

if metric == 'proposal_fast':
if isinstance(results[0], tuple):
raise KeyError('proposal_fast is not supported for '
'instance segmentation result.')
ar = self.fast_eval_recall(
results, proposal_nums, iou_thrs, logger='silent')
log_msg = []
Expand Down Expand Up @@ -476,15 +475,15 @@ def evaluate(self,
'of small/medium/large instances since v2.12.0. This '
'does not change the overall mAP calculation.',
UserWarning)
cocoDt = cocoGt.loadRes(predictions)
coco_det = coco_gt.loadRes(predictions)
except IndexError:
print_log(
'The testing results of the whole dataset is empty.',
logger=logger,
level=logging.ERROR)
break

cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
cocoEval = COCOeval(coco_gt, coco_det, iou_type)
cocoEval.params.catIds = self.cat_ids
cocoEval.params.imgIds = self.img_ids
cocoEval.params.maxDets = list(proposal_nums)
Expand Down Expand Up @@ -590,6 +589,64 @@ def evaluate(self,
eval_results[f'{metric}_mAP_copypaste'] = (
f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
f'{ap[4]:.3f} {ap[5]:.3f}')

return eval_results

def evaluate(self,
results,
metric='bbox',
logger=None,
jsonfile_prefix=None,
classwise=False,
proposal_nums=(100, 300, 1000),
iou_thrs=None,
metric_items=None):
"""Evaluation in COCO protocol.
Args:
results (list[list | tuple]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. Options are
'bbox', 'segm', 'proposal', 'proposal_fast'.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes
the file path and the prefix of filename, e.g., "a/b/prefix".
If not specified, a temp file will be created. Default: None.
classwise (bool): Whether to evaluating the AP for each class.
proposal_nums (Sequence[int]): Proposal number used for evaluating
recalls, such as recall@100, recall@1000.
Default: (100, 300, 1000).
iou_thrs (Sequence[float], optional): IoU threshold used for
evaluating recalls/mAPs. If set to a list, the average of all
IoUs will also be computed. If not specified, [0.50, 0.55,
0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
Default: None.
metric_items (list[str] | str, optional): Metric items that will
be returned. If not specified, ``['AR@100', 'AR@300',
'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
``metric=='bbox' or metric=='segm'``.
Returns:
dict[str, float]: COCO style evaluation metric.
"""

metrics = metric if isinstance(metric, list) else [metric]
allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
for metric in metrics:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')

coco_gt = self.coco
self.cat_ids = coco_gt.get_cat_ids(cat_names=self.CLASSES)

result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
eval_results = self.evaluate_det_segm(results, result_files, coco_gt,
metrics, logger, classwise,
proposal_nums, iou_thrs,
metric_items)

if tmp_dir is not None:
tmp_dir.cleanup()
return eval_results
117 changes: 105 additions & 12 deletions mmdet/datasets/coco_panoptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,29 @@ class CocoPanopticDataset(CocoDataset):
},
...
]
Args:
ann_file (str): Panoptic segmentation annotation file path.
pipeline (list[dict]): Processing pipeline.
ins_ann_file (str): Instance segmentation annotation file path.
Defaults to None.
classes (str | Sequence[str], optional): Specify classes to load.
If is None, ``cls.CLASSES`` will be used. Defaults to None.
data_root (str, optional): Data root for ``ann_file``,
``ins_ann_file`` ``img_prefix``, ``seg_prefix``, ``proposal_file``
if specified. Defaults to None.
img_prefix (str, optional): Prefix of path to images. Defaults to ''.
seg_prefix (str, optional): Prefix of path to segmentation files.
Defaults to None.
proposal_file (str, optional): Path to proposal file. Defaults to None.
test_mode (bool, optional): If set True, annotation will not be loaded.
Defaults to False.
filter_empty_gt (bool, optional): If set true, images without bounding
boxes of the dataset's classes will be filtered out. This option
only works when `test_mode=False`, i.e., we never filter images
during tests. Defaults to True.
file_client_args (:obj:`mmcv.ConfigDict` | dict): file client args.
Defaults to dict(backend='disk').
"""
CLASSES = [
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
Expand Down Expand Up @@ -233,6 +256,31 @@ class CocoPanopticDataset(CocoDataset):
(206, 186, 171), (152, 161, 64), (116, 112, 0), (0, 114, 143),
(102, 102, 156), (250, 141, 255)]

def __init__(self,
ann_file,
pipeline,
ins_ann_file=None,
classes=None,
data_root=None,
img_prefix='',
seg_prefix=None,
proposal_file=None,
test_mode=False,
filter_empty_gt=True,
file_client_args=dict(backend='disk')):
super().__init__(
ann_file,
pipeline,
classes=classes,
data_root=data_root,
img_prefix=img_prefix,
seg_prefix=seg_prefix,
proposal_file=proposal_file,
test_mode=test_mode,
filter_empty_gt=filter_empty_gt,
file_client_args=file_client_args)
self.ins_ann_file = ins_ann_file

def load_annotations(self, ann_file):
"""Load annotation from COCO Panoptic style annotation file.
Expand Down Expand Up @@ -402,23 +450,41 @@ def _pan2json(self, results, outfile_prefix):
return pan_json_results

def results2json(self, results, outfile_prefix):
"""Dump the panoptic results to a COCO panoptic style json file.
"""Dump the results to a COCO style json file.
There are 4 types of results: proposals, bbox predictions, mask
predictions, panoptic segmentation predictions, and they have
different data types. This method will automatically recognize
the type, and dump them to json files.
Args:
results (dict): Testing results of the dataset.
outfile_prefix (str): The filename prefix of the json files. If the
prefix is "somepath/xxx", the json files will be named
"somepath/xxx.panoptic.json"
"somepath/xxx.panoptic.json", "somepath/xxx.bbox.json",
"somepath/xxx.segm.json"
Returns:
dict[str: str]: The key is 'panoptic' and the value is
corresponding filename.
dict[str: str]: Possible keys are "panoptic", "bbox", "segm", \
"proposal", and values are corresponding filenames.
"""
result_files = dict()
pan_results = [result['pan_results'] for result in results]
pan_json_results = self._pan2json(pan_results, outfile_prefix)
result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'
mmcv.dump(pan_json_results, result_files['panoptic'])
# panoptic segmentation results
if 'pan_results' in results[0]:
pan_results = [result['pan_results'] for result in results]
pan_json_results = self._pan2json(pan_results, outfile_prefix)
result_files['panoptic'] = f'{outfile_prefix}.panoptic.json'
mmcv.dump(pan_json_results, result_files['panoptic'])

# instance segmentation results
if 'ins_results' in results[0]:
ins_results = [result['ins_results'] for result in results]
bbox_json_results, segm_json_results = self._segm2json(ins_results)
result_files['bbox'] = f'{outfile_prefix}.bbox.json'
result_files['proposal'] = f'{outfile_prefix}.bbox.json'
result_files['segm'] = f'{outfile_prefix}.segm.json'
mmcv.dump(bbox_json_results, result_files['bbox'])
mmcv.dump(segm_json_results, result_files['segm'])

return result_files

Expand Down Expand Up @@ -476,8 +542,16 @@ def evaluate_pan_json(self,
for k, v in zip(self.CLASSES, pq_results['classwise'].values())
}
print_panoptic_table(pq_results, classwise_results, logger=logger)
results = parse_pq_results(pq_results)
results['PQ_copypaste'] = (
f'{results["PQ"]:.3f} {results["SQ"]:.3f} '
f'{results["RQ"]:.3f} '
f'{results["PQ_th"]:.3f} {results["SQ_th"]:.3f} '
f'{results["RQ_th"]:.3f} '
f'{results["PQ_st"]:.3f} {results["SQ_st"]:.3f} '
f'{results["RQ_st"]:.3f}')

return parse_pq_results(pq_results)
return results

def evaluate(self,
results,
Expand All @@ -491,8 +565,8 @@ def evaluate(self,
Args:
results (list[dict]): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. Only
support 'PQ' at present. 'pq' will be regarded as 'PQ.
metric (str | list[str]): Metrics to be evaluated. 'PQ', 'bbox',
'segm', 'proposal' are supported. 'pq' will be regarded as 'PQ.
logger (logging.Logger | str | None): Logger used for printing
related information during evaluation. Default: None.
jsonfile_prefix (str | None): The prefix of json files. It includes
Expand All @@ -510,7 +584,7 @@ def evaluate(self,
metrics = metric if isinstance(metric, list) else [metric]
# Compatible with lowercase 'pq'
metrics = ['PQ' if metric == 'pq' else metric for metric in metrics]
allowed_metrics = ['PQ'] # todo: support other metrics like 'bbox'
allowed_metrics = ['PQ', 'bbox', 'segm', 'proposal']
for metric in metrics:
if metric not in allowed_metrics:
raise KeyError(f'metric {metric} is not supported')
Expand All @@ -524,6 +598,25 @@ def evaluate(self,
eval_pan_results = self.evaluate_pan_json(
result_files, outfile_prefix, logger, classwise, nproc=nproc)
eval_results.update(eval_pan_results)
metrics.remove('PQ')

if (('bbox' in metrics) or ('segm' in metrics)
or ('proposal' in metrics)):

assert 'ins_results' in results[0], 'instance segmentation' \
'results are absent from results'

assert self.ins_ann_file is not None, 'Annotation '\
'file for instance segmentation or object detection ' \
'shuold not be None'

coco_gt = COCO(self.ins_ann_file)
self.cat_ids = coco_gt.get_cat_ids(cat_names=self.THING_CLASSES)

eval_ins_results = self.evaluate_det_segm(results, result_files,
coco_gt, metrics, logger,
classwise, **kwargs)
eval_results.update(eval_ins_results)

if tmp_dir is not None:
tmp_dir.cleanup()
Expand Down
Loading

0 comments on commit ab662f9

Please sign in to comment.