Skip to content

Commit

Permalink
Merge pull request open-mmlab#39 from opencv/feature/ik/concat_coco
Browse files Browse the repository at this point in the history
Feature/ik/concat coco
  • Loading branch information
Ilya-Krylov authored May 28, 2020
2 parents e13a2a1 + eee4cea commit 6af96ce
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 22 deletions.
48 changes: 30 additions & 18 deletions mmdet/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,29 @@

from .dataset_wrappers import ConcatDataset, RepeatDataset
from .registry import DATASETS
from .coco import CocoDataset, ConcatenatedCocoDataset


def get_image_prefixes_auto(cfg, ann_files):
    """Derive an image directory for every annotation file.

    The ``img_prefix_auto`` flag is removed from ``cfg`` as a side effect, so
    the config no longer carries it after this call.

    Args:
        cfg: Mutable dataset config mapping. It must contain a ``type`` entry
            and must not carry an explicit ``img_prefix``.
        ann_files: Iterable of annotation file paths.

    Returns:
        list: One image directory path per entry of ``ann_files``, in the
        same order.

    Raises:
        AssertionError: If ``cfg`` already has a non-``None`` ``img_prefix``.
        NotImplementedError: If ``cfg['type']`` is not ``'CustomCocoDataset'``.
    """
    # pop() instead of del: a config that never had the flag is not an error.
    cfg.pop('img_prefix_auto', None)

    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O`` while callers still see the same exception type.
    if cfg.get('img_prefix', None) is not None:
        raise AssertionError(
            'img_prefix must not be set when img_prefix_auto is used')

    if cfg['type'] != 'CustomCocoDataset':
        raise NotImplementedError(
            'img_prefix_auto is not supported for dataset type '
            '{!r}'.format(cfg['type']))

    # assuming following dataset structure:
    # dataset_root
    # ├── annotations
    # │   ├── instances_train.json
    # │   ├── ...
    # ├── images
    #     ├── image_name1
    #     ├── image_name2
    #     ├── ...
    # and file_name inside instances_train.json is relative to
    # <dataset_root>/images
    return [
        os.path.join(os.path.dirname(ann_file), '..', 'images')
        for ann_file in ann_files
    ]


def _concat_dataset(cfg, default_args=None):
Expand All @@ -15,23 +38,7 @@ def _concat_dataset(cfg, default_args=None):
proposal_files = cfg.get('proposal_file', None)

if cfg.get('img_prefix_auto', False):
del cfg['img_prefix_auto']
assert img_prefixes is None
if cfg['type'] == 'CustomCocoDataset':
# assuming following dataset structure:
# dataset_root
# ├── annotations
# │ ├── instances_train.json
# │ ├── ...
# ├── images
# ├── image_name1
# ├── image_name2
# ├── ...
# and file_name inside instances_train.json is relative to <dataset_root>/images
img_prefixes = \
[os.path.join(os.path.dirname(ann_file), '..', 'images') for ann_file in ann_files]
else:
raise NotImplementedError
img_prefixes = get_image_prefixes_auto(cfg, ann_files)

datasets = []
num_dset = len(ann_files)
Expand All @@ -46,7 +53,10 @@ def _concat_dataset(cfg, default_args=None):
data_cfg['proposal_file'] = proposal_files[i]
datasets.append(build_dataset(data_cfg, default_args))

return ConcatDataset(datasets)
concatenated_dataset = ConcatDataset(datasets)
if all(isinstance(dataset, CocoDataset) for dataset in concatenated_dataset.datasets):
concatenated_dataset = ConcatenatedCocoDataset(concatenated_dataset)
return concatenated_dataset


def build_dataset(cfg, default_args=None):
Expand All @@ -64,6 +74,8 @@ def build_dataset(cfg, default_args=None):
f'{cfg["ann_file"]}')
cfg['ann_file'] = matches
if len(cfg['ann_file']) == 1:
if cfg.get('img_prefix_auto', False):
cfg['img_prefix'] = get_image_prefixes_auto(cfg, cfg['ann_file'])[0]
cfg['ann_file'] = cfg['ann_file'][0]
dataset = build_from_cfg(cfg, DATASETS, default_args)
else:
Expand Down
72 changes: 72 additions & 0 deletions mmdet/datasets/coco.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import os
from pycocotools.coco import COCO

from .custom import CustomDataset
Expand Down Expand Up @@ -108,3 +109,74 @@ def _parse_ann_info(self, img_info, ann_info):
seg_map=seg_map)

return ann


@DATASETS.register_module
class ConcatenatedCocoDataset(CocoDataset):
    """Single COCO-style dataset assembled from several ``CocoDataset`` objects.

    The members of ``concatenated_dataset.datasets`` must agree on category
    ids, category-to-label mapping, pipeline (compared by ``str``) and
    proposals. Their image lists and COCO indices are merged into this object;
    image ids and annotation ids of every dataset after the first are offset
    so they do not collide with ids that were merged earlier.

    Note that the merge works in place: the image/annotation dicts and the
    ``coco`` object of the source datasets are modified and reused.
    """

    def __init__(self, concatenated_dataset):
        """Merge the datasets held by ``concatenated_dataset``.

        Args:
            concatenated_dataset: Object exposing a ``datasets`` sequence of
                ``CocoDataset`` instances.

        Raises:
            AssertionError: If a member is not a ``CocoDataset`` or differs
                from the first member in categories, pipeline or proposals.
        """
        datasets = concatenated_dataset.datasets
        reference = datasets[0]

        for dataset in datasets:
            assert isinstance(dataset, CocoDataset), type(dataset)
            assert dataset.cat_ids == reference.cat_ids
            assert dataset.cat2label == reference.cat2label
            assert str(dataset.pipeline) == str(reference.pipeline), \
                f'{dataset.pipeline}'
            assert dataset.proposals == reference.proposals

        # Settings shared by all members are taken from the first one.
        self.CLASSES = reference.CLASSES
        self.test_mode = reference.test_mode
        self.filter_empty_gt = reference.filter_empty_gt
        self.cat_ids = reference.cat_ids
        self.cat2label = reference.cat2label
        self.pipeline = reference.pipeline
        self.proposals = reference.proposals
        self.img_ids = []
        self.img_infos = []
        self.flag = None
        self.ann_infos = []
        # File names are made absolute below, so no common prefix is kept.
        self.img_prefix = None
        self.seg_prefix = None
        self.proposal_file = None
        self.coco = None

        for dataset in datasets:
            # Offset that moves this dataset's image ids past all merged ones.
            img_shift = max(self.img_ids) + 1 if self.img_ids else 0
            # Identities of image dicts already offset for this dataset; see
            # _shift_image_id for why this is needed.
            seen = set()

            self.img_ids.extend(
                img_id + img_shift for img_id in dataset.img_ids)

            for im_info in dataset.img_infos:
                self._shift_image_id(im_info, img_shift, seen)
                im_info['filename'] = os.path.join(
                    dataset.img_prefix, im_info['filename'])
                self.img_infos.append(im_info)

            if self.coco is None:
                # The first dataset's COCO object becomes the merged index.
                self.coco = dataset.coco
                self.coco.dataset = {
                    'images': dataset.coco.dataset['images'],
                    'categories': dataset.coco.dataset['categories'],
                }
            else:
                for cat, cat_img_ids in dataset.coco.catToImgs.items():
                    self.coco.catToImgs[cat].extend(
                        img_id + img_shift for img_id in cat_img_ids)

                # default=-1 keeps this working when nothing was merged yet.
                ann_shift = max(self.coco.anns, default=-1) + 1
                for k, v in dataset.coco.anns.items():
                    v['image_id'] += img_shift
                    v['id'] += ann_shift
                    self.coco.anns[k + ann_shift] = v

                for k, v in dataset.coco.imgs.items():
                    self._shift_image_id(v, img_shift, seen)
                    self.coco.imgs[k + img_shift] = v

                for k, v in dataset.coco.imgToAnns.items():
                    # indices in annotations have been changed above
                    self.coco.imgToAnns[k + img_shift] = v

                for v in dataset.coco.dataset['images']:
                    self._shift_image_id(v, img_shift, seen)
                    self.coco.dataset['images'].append(v)

            if hasattr(dataset, 'flag'):
                if self.flag is None:
                    self.flag = dataset.flag
                else:
                    # Append to the flags gathered so far; np.concatenate
                    # takes a sequence of arrays.
                    self.flag = np.concatenate(
                        [self.flag, dataset.flag], axis=0)

    @staticmethod
    def _shift_image_id(info, offset, seen):
        """Add ``offset`` to ``info['id']`` at most once per dict object.

        pycocotools builds ``COCO.imgs`` out of references to the entries of
        ``COCO.dataset['images']``, and ``img_infos`` presumably holds those
        same dicts as well (NOTE(review): confirm against
        ``CocoDataset.load_annotations``). Offsetting every collection
        independently would therefore move a shared dict several times, so
        already handled objects are tracked by identity in ``seen``.
        """
        key = id(info)
        if key in seen:
            return
        seen.add(key)
        info['id'] += offset
8 changes: 4 additions & 4 deletions mmdet/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,8 +526,8 @@ def __repr__(self):
repr_str = self.__class__.__name__
repr_str += ('(brightness_delta={}, contrast_range={}, '
'saturation_range={}, hue_delta={})').format(
self.brightness_delta, self.contrast_range,
self.saturation_range, self.hue_delta)
self.brightness_delta, (self.contrast_lower, self.contrast_upper),
(self.saturation_lower, self.saturation_upper), self.hue_delta)
return repr_str


Expand Down Expand Up @@ -691,8 +691,8 @@ def __call__(self, results):

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += '(min_ious={}, min_crop_size={})'.format(
self.min_ious, self.min_crop_size)
repr_str += '(sample_mode={}, min_crop_size={})'.format(
self.sample_mode, self.min_crop_size)
return repr_str


Expand Down

0 comments on commit 6af96ce

Please sign in to comment.