Skip to content

Commit

Permalink
[Feature] support mscoco dataset (#1520)
Browse files Browse the repository at this point in the history
* support reading mscoco dataset with caption

* fix lint

* add unit test

* fix isort

* fix lint

* fix lint
  • Loading branch information
plyfager authored Dec 14, 2022
1 parent 69675c3 commit 0cf8177
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 10 deletions.
15 changes: 5 additions & 10 deletions mmedit/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,14 @@
from .comp1k_dataset import AdobeComp1kDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
from .paired_image_dataset import PairedImageDataset
from .singan_dataset import SinGANDataset
from .unpaired_image_dataset import UnpairedImageDataset

__all__ = [
'AdobeComp1kDataset',
'BasicImageDataset',
'BasicFramesDataset',
'BasicConditionalDataset',
'UnpairedImageDataset',
'PairedImageDataset',
'ImageNet',
'CIFAR10',
'GrowScaleImgDataset',
'SinGANDataset',
'AdobeComp1kDataset', 'BasicImageDataset', 'BasicFramesDataset',
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset'
]
101 changes: 101 additions & 0 deletions mmedit/datasets/mscoco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
from typing import Optional, Sequence, Union

import mmengine
from mmengine import FileClient

from mmedit.registry import DATASETS
from .basic_conditional_dataset import BasicConditionalDataset


@DATASETS.register_module()
@DATASETS.register_module('MSCOCO')
class MSCoCoDataset(BasicConditionalDataset):
"""MSCoCo 2014 dataset.
Args:
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str): The root directory for ``data_prefix`` and
``ann_file``. Defaults to ''.
drop_caption_rate (float, optional): Rate of dropping caption,
used for training. Defaults to 0.0.
phase (str, optional): Subdataset used for certain phase, can be set
to `train`, `test` and `val`. Defaults to 'train'.
year (int, optional): Version of CoCo dataset, can be set to 2014
and 2017. Defaults to 2014.
data_prefix (str | dict): Prefix for the data. Defaults to ''.
extensions (Sequence[str]): A sequence of allowed extensions. Defaults
to ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif').
lazy_init (bool): Whether to load annotation during instantiation.
In some cases, such as visualization, only the meta information of
the dataset is needed, which is not necessary to load annotation
file. ``Basedataset`` can skip load annotations to save time by set
``lazy_init=False``. Defaults to False.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""
METAINFO = dict(dataset_type='text_image_dataset', task_name='editing')

def __init__(self,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: str = '',
drop_caption_rate=0.0,
phase='train',
year=2014,
data_prefix: Union[str, dict] = '',
extensions: Sequence[str] = ('.jpg', '.jpeg', '.png', '.ppm',
'.bmp', '.pgm', '.tif'),
lazy_init: bool = False,
classes: Union[str, Sequence[str], None] = None,
**kwargs):
ann_file = os.path.join('annotations', 'captions_' + phase +
f'{year}.json') if ann_file == '' else ann_file
self.image_prename = 'COCO_' + phase + f'{year}_'
self.phase = phase
self.drop_rate = drop_caption_rate
self.year = year
assert self.year == 2014, 'We only support CoCo2014 now.'

super().__init__(
ann_file=ann_file,
metainfo=metainfo,
data_root=data_root,
data_prefix=data_prefix,
extensions=extensions,
lazy_init=lazy_init,
classes=classes,
**kwargs)

def load_data_list(self):
"""Load image paths and gt_labels."""
if self.img_prefix:
file_client = FileClient.infer_client(uri=self.img_prefix)
json_file = mmengine.fileio.io.load(self.ann_file)

def add_prefix(filename, prefix=''):
if not prefix:
return filename
else:
return file_client.join_path(prefix, filename)

data_list = []
for item in json_file['annotations']:
image_name = self.image_prename + str(
item['image_id']).zfill(12) + '.jpg'
img_path = add_prefix(
os.path.join(self.phase + str(self.year), image_name),
self.img_prefix)
caption = item['caption'].lower()
info = {
'img_path':
img_path,
'gt_label':
caption if (self.phase != 'train' or self.drop_rate < 1e-6
or random.random() >= self.drop_rate) else ''
}
data_list.append(info)
return data_list
3 changes: 3 additions & 0 deletions tests/data/coco/annotations/captions_train2014.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"annotations": [{"image_id": 9, "caption": "a good meal"}]
}
3 changes: 3 additions & 0 deletions tests/data/coco/annotations/captions_val2014.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"annotations": [{"image_id": 42, "caption": "a pair of slippers"}]
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 31 additions & 0 deletions tests/test_datasets/test_mscoco_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from pathlib import Path

from mmedit.datasets import MSCoCoDataset


class TestMSCoCoDatasets:

@classmethod
def setup_class(cls):
cls.data_root = Path(__file__).parent.parent / 'data' / 'coco'

def test_mscoco(self):

# test basic usage
dataset = MSCoCoDataset(data_root=self.data_root, pipeline=[])
assert dataset[0] == dict(
gt_label='a good meal',
img_path=os.path.join(self.data_root, 'train2014',
'COCO_train2014_000000000009.jpg'),
sample_idx=0)

# test with different phase
dataset = MSCoCoDataset(
data_root=self.data_root, phase='val', pipeline=[])
assert dataset[0] == dict(
gt_label='a pair of slippers',
img_path=os.path.join(self.data_root, 'val2014',
'COCO_val2014_000000000042.jpg'),
sample_idx=0)

0 comments on commit 0cf8177

Please sign in to comment.