Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Add ignore_keys in ConcatDataset #556

Merged
merged 13 commits into from
Nov 1, 2022
34 changes: 29 additions & 5 deletions mmengine/dataset/dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ class ConcatDataset(_ConcatDataset):
which will be concatenated.
lazy_init (bool, optional): Whether to load annotation during
instantiation. Defaults to False.
ignore_keys (List[str] or str): Ignore the keys that can be
unequal in `dataset.metainfo`. Defaults to None.
`New in version 0.3.0.`
"""

def __init__(self,
datasets: Sequence[Union[BaseDataset, dict]],
lazy_init: bool = False):
lazy_init: bool = False,
ignore_keys: Union[str, List[str], None] = None):
self.datasets: List[BaseDataset] = []
for i, dataset in enumerate(datasets):
if isinstance(dataset, dict):
Expand All @@ -45,13 +49,33 @@ def __init__(self,
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
if ignore_keys is None:
self.ignore_keys = []
elif isinstance(ignore_keys, str):
self.ignore_keys = [ignore_keys]
elif isinstance(ignore_keys, list):
self.ignore_keys = ignore_keys
else:
raise TypeError('ignore_keys should be a list or str, '
f'but got {type(ignore_keys)}')

meta_keys: set = set()
for dataset in self.datasets:
meta_keys |= dataset.metainfo.keys()
# Only use metainfo of first dataset.
self._metainfo = self.datasets[0].metainfo
for i, dataset in enumerate(self.datasets, 1):
if self._metainfo != dataset.metainfo:
raise ValueError(
f'The meta information of the {i}-th dataset does not '
'match meta information of the first dataset')
for key in meta_keys:
if key in self.ignore_keys:
continue
if key not in dataset.metainfo:
raise ValueError(
f'{key} does not in the meta information of '
f'the {i}-th dataset')
if self._metainfo[key] != dataset.metainfo[key]:
raise ValueError(
f'The meta information of the {i}-th dataset does not '
'match meta information of the first dataset')

self._fully_initialized = False
if not lazy_init:
Expand Down
35 changes: 29 additions & 6 deletions tests/test_dataset/test_base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,10 @@ def test_init(self):
with pytest.raises(TypeError):
ConcatDataset(datasets=[0])

with pytest.raises(TypeError):
ConcatDataset(
datasets=[self.dataset_a, dataset_cfg_b], ignore_keys=1)

def test_full_init(self):
# test init with lazy_init=True
self.cat_datasets.full_init()
Expand All @@ -654,14 +658,33 @@ def test_full_init(self):

with pytest.raises(NotImplementedError):
self.cat_datasets.get_subset(1)
# Different meta information will raise error.

dataset_b = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=dict(classes=('cat')))
# Regardless of order, different meta information without
# `ignore_keys` will raise error.
with pytest.raises(ValueError):
dataset_b = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json',
metainfo=dict(classes=('cat')))
ConcatDataset(datasets=[self.dataset_a, dataset_b])
with pytest.raises(ValueError):
ConcatDataset(datasets=[dataset_b, self.dataset_a])
# `ignore_keys` does not contain different meta information keys will
# raise error.
with pytest.raises(ValueError):
ConcatDataset(
datasets=[self.dataset_a, dataset_b], ignore_keys=['a'])
# Different meta information with `ignore_keys` will not raise error.
cat_datasets = ConcatDataset(
datasets=[self.dataset_a, dataset_b], ignore_keys='classes')
cat_datasets.full_init()
assert len(cat_datasets) == 6
cat_datasets.full_init()
cat_datasets._fully_initialized = False
cat_datasets[1]
assert len(cat_datasets.metainfo) == 3
assert len(cat_datasets) == 6

def test_metainfo(self):
assert self.cat_datasets.metainfo == self.dataset_a.metainfo
Expand Down