Skip to content

Commit

Permalink
utilizing defaultdict, optimize single_selection check
Browse files Browse the repository at this point in the history
  • Loading branch information
bonhunko committed Oct 18, 2022
1 parent 192820f commit 2e117c5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
22 changes: 17 additions & 5 deletions datumaro/plugins/datumaro_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import os.path as osp
import shutil
from collections import defaultdict

import numpy as np
import pycocotools.mask as mask_utils
Expand Down Expand Up @@ -336,9 +337,13 @@ def apply(self):
for writer in writers.values():
writer.add_categories(self._extractor.categories())

single_selection_info = self._extruct_single_selection_info(self._extractor.categories())

for item in self._extractor:
subset = item.subset or DEFAULT_SUBSET_NAME
item = self._filterout_for_single_selection(item, self._extractor.categories())
item = self._filterout_for_single_selection(
item, self._extractor.categories(), single_selection_info
)
writers[subset].add_item(item)

for subset, writer in writers.items():
Expand Down Expand Up @@ -388,9 +393,9 @@ def patch(cls, dataset, patch, save_dir, **kwargs):
shutil.rmtree(related_images_path)

@staticmethod
def _filterout_for_single_selection(item, categories):
def _extruct_single_selection_info(categories):
if AnnotationType.label not in categories:
return item
return {}

name2parent_ss = {}
for label_category in categories[AnnotationType.label]:
Expand All @@ -399,8 +404,15 @@ def _filterout_for_single_selection(item, categories):
label_category.single_selection,
)

return name2parent_ss

@staticmethod
def _filterout_for_single_selection(item, categories, name2parent_ss):
if AnnotationType.label not in categories:
return item

# collect childrens that have the same parent which only allow single-selection
parent2ss_indices = {}
parent2ss_indices = defaultdict(list)
for i, annotation in enumerate(item.annotations):
if annotation._type != AnnotationType.label:
continue
Expand All @@ -414,7 +426,7 @@ def _filterout_for_single_selection(item, categories):
_, single_selection = name2parent_ss[parent]

if single_selection:
parent2ss_indices[parent] = parent2ss_indices.get(parent, []) + [i]
parent2ss_indices[parent] += [i]

# remove labels that dis-obey the single-selection rule
for indices in parent2ss_indices.values():
Expand Down
18 changes: 10 additions & 8 deletions tests/test_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: MIT
import os.path as osp
import tempfile
from unittest.case import TestCase

import numpy as np
Expand All @@ -13,8 +14,6 @@

from .requirements import Requirements, mark_requirement

TMP_DATASET_DIR = osp.join(osp.dirname(__file__), "tmp", "test_labeling")


class LabelingTest(TestCase):
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
Expand Down Expand Up @@ -79,8 +78,9 @@ def test_label_single_selection_filtered(self):
},
)

dataset.export(TMP_DATASET_DIR, format="datumaro")
dataset_imported = Dataset.import_from(TMP_DATASET_DIR, format="datumaro")
with tempfile.TemporaryDirectory() as temp_home:
dataset.export(temp_home, format="datumaro")
dataset_imported = Dataset.import_from(temp_home, format="datumaro")

for item in dataset_imported:
self.assertEqual(len(item.annotations), 0)
Expand Down Expand Up @@ -115,8 +115,9 @@ def test_label_single_selection_not_filtered(self):
},
)

dataset.export(TMP_DATASET_DIR, format="datumaro")
dataset_imported = Dataset.import_from(TMP_DATASET_DIR, format="datumaro")
with tempfile.TemporaryDirectory() as temp_home:
dataset.export(temp_home, format="datumaro")
dataset_imported = Dataset.import_from(temp_home, format="datumaro")

for item in dataset_imported:
self.assertEqual(len(item.annotations), 2)
Expand Down Expand Up @@ -147,8 +148,9 @@ def test_label_single_selection_correct(self):
},
)

dataset.export(TMP_DATASET_DIR, format="datumaro")
dataset_imported = Dataset.import_from(TMP_DATASET_DIR, format="datumaro")
with tempfile.TemporaryDirectory() as temp_home:
dataset.export(temp_home, format="datumaro")
dataset_imported = Dataset.import_from(temp_home, format="datumaro")

for item in dataset_imported:
self.assertEqual(len(item.annotations), 1)

0 comments on commit 2e117c5

Please sign in to comment.