Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Datumaro] Add DatasetItem attributes #1639

Merged
merged 3 commits into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def __eq__(self, other):
class DatasetItem:
# pylint: disable=redefined-builtin
def __init__(self, id=None, annotations=None,
subset=None, path=None, image=None):
subset=None, path=None, image=None, attributes=None):
assert id is not None
self._id = str(id)

Expand Down Expand Up @@ -604,6 +604,12 @@ def __init__(self, id=None, annotations=None,
image = Image(path=image)
assert image is None or isinstance(image, Image)
self._image = image

if attributes is None:
attributes = {}
else:
attributes = dict(attributes)
self._attributes = attributes
# pylint: enable=redefined-builtin

@property
Expand All @@ -630,6 +636,10 @@ def image(self):
def has_image(self):
return self._image is not None

@property
def attributes(self):
    """Return the mapping stored in ``self._attributes``.

    The stored object itself is returned, not a copy.
    """
    stored = self._attributes
    return stored

def __eq__(self, other):
if not isinstance(other, __class__):
return False
Expand All @@ -638,10 +648,12 @@ def __eq__(self, other):
(self.subset == other.subset) and \
(self.path == other.path) and \
(self.annotations == other.annotations) and \
(self.image == other.image)
(self.image == other.image) and \
(self.attributes == other.attributes)

def wrap(item, **kwargs):
expected_args = {'id', 'annotations', 'subset', 'path', 'image'}
expected_args = {'id', 'annotations', 'subset',
'path', 'image', 'attributes'}
for k in expected_args:
if k not in kwargs:
kwargs[k] = getattr(item, k)
Expand Down
2 changes: 2 additions & 0 deletions datumaro/datumaro/plugins/datumaro_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def write_item(self, item):
'id': item.id,
'annotations': annotations,
}
if item.attributes:
item_desc['attr'] = item.attributes
if item.path:
item_desc['path'] = item.path
if item.has_image:
Expand Down
6 changes: 4 additions & 2 deletions datumaro/datumaro/plugins/datumaro_format/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,15 @@ def _load_items(self, parsed):
annotations = self._load_annotations(item_desc)

item = DatasetItem(id=item_id, subset=self._subset,
annotations=annotations, image=image)
annotations=annotations, image=image,
attributes=item_desc.get('attr'))

items.append(item)

return items

def _load_annotations(self, item):
@staticmethod
def _load_annotations(item):
parsed = item['annotations']
loaded = []

Expand Down
19 changes: 18 additions & 1 deletion datumaro/datumaro/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,21 @@ def compare_datasets(test, expected, actual):

ann_b = find(ann_b_matches, lambda x: x == ann_a)
test.assertEqual(ann_a, ann_b, 'ann: %s' % ann_to_str(ann_a))
item_b.annotations.remove(ann_b) # avoid repeats
item_b.annotations.remove(ann_b) # avoid repeats

def compare_datasets_strict(test, expected, actual):
    """Assert that two datasets are strictly equal.

    Checks, in order: the categories, the sorted subset names, the total
    length, and then for every subset of ``expected`` the subset length and
    each pair of items taken in iteration order.

    Args:
        test: object providing ``assertEqual`` / ``assertListEqual``
            (presumably a ``unittest.TestCase`` - confirm against callers).
        expected: the reference dataset.
        actual: the dataset under test.
    """
    test.assertEqual(expected.categories(), actual.categories())

    expected_names = sorted(expected.subsets())
    actual_names = sorted(actual.subsets())
    test.assertListEqual(expected_names, actual_names)
    test.assertEqual(len(expected), len(actual))

    for name in expected.subsets():
        want_subset = expected.get_subset(name)
        got_subset = actual.get_subset(name)
        test.assertEqual(len(want_subset), len(got_subset))

        paired = zip(want_subset, got_subset)
        for position, (want_item, got_item) in enumerate(paired):
            # The message is built before the comparison, so both items are
            # always rendered, whether or not they match.
            details = '%s:\n%s\nvs.\n%s\n' % (
                position, item_to_str(want_item), item_to_str(got_item))
            test.assertEqual(want_item, got_item, details)
52 changes: 21 additions & 31 deletions datumaro/tests/test_datumaro_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,24 @@
from datumaro.plugins.datumaro_format.converter import DatumaroConverter
from datumaro.util.mask_tools import generate_colormap
from datumaro.util.image import Image
from datumaro.util.test_utils import TestDir, item_to_str

from datumaro.util.test_utils import TestDir, compare_datasets_strict

class DatumaroConverterTest(TestCase):
def _test_save_and_load(self, source_dataset, converter, test_dir,
        target_dataset=None, importer_args=None):
    """Save a dataset with ``converter``, re-import it and compare strictly.

    Args:
        source_dataset: dataset passed to ``converter``.
        converter: callable invoked as ``converter(source_dataset, test_dir)``.
        test_dir: directory the converter writes to and the importer reads.
        target_dataset: dataset the re-imported result must equal; when
            ``None``, ``source_dataset`` is used.
        importer_args: extra keyword arguments for ``Project.import_from``;
            when ``None``, no extra arguments are passed.
    """
    converter(source_dataset, test_dir)

    extra_args = {} if importer_args is None else importer_args
    project = Project.import_from(test_dir, 'datumaro', **extra_args)
    reloaded = project.make_dataset()

    reference = source_dataset if target_dataset is None else target_dataset
    compare_datasets_strict(self, expected=reference, actual=reloaded)

class TestExtractor(Extractor):
def __iter__(self):
return iter([
Expand Down Expand Up @@ -47,7 +61,8 @@ def __iter__(self):
Polygon([1, 2, 3, 4, 5, 6, 7, 8], id=12, z_order=4),
]),

DatasetItem(id=42, subset='test'),
DatasetItem(id=42, subset='test',
attributes={'a1': 5, 'a2': '42'}),

DatasetItem(id=42),
DatasetItem(id=43, image=Image(path='1/b/c.qq', size=(2, 4))),
Expand All @@ -73,36 +88,11 @@ def categories(self):

def test_can_save_and_load(self):
    """Round-trip the sample dataset through the Datumaro format.

    The dataset from ``TestExtractor`` is written with
    ``DatumaroConverter(save_images=True)`` into a temporary ``TestDir``;
    saving, re-importing and comparing are delegated to
    ``_test_save_and_load``.
    """
    with TestDir() as test_dir:
        # Delegate to the shared helper only: keeping the inline
        # save/import/compare sequence as well would convert the dataset
        # into the same directory twice and repeat every check.
        self._test_save_and_load(self.TestExtractor(),
            DatumaroConverter(save_images=True), test_dir)

def test_can_detect(self):
    """Check that ``DatumaroImporter.detect`` accepts a saved dataset.

    The sample dataset is written with a default ``DatumaroConverter`` into
    a temporary ``TestDir``, which is then passed to the importer's
    ``detect``.
    """
    with TestDir() as test_dir:
        DatumaroConverter()(self.TestExtractor(), save_dir=test_dir)

        # A single assertion is enough; the verbatim duplicate is dropped.
        self.assertTrue(DatumaroImporter.detect(test_dir))