diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fe6a63612..da859b025e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Tensorflow AVX check is made optional in API and is disabled by default () +- Extensions for images in ImageNet_txt are now mandatory () ### Deprecated - diff --git a/datumaro/plugins/imagenet_txt_format.py b/datumaro/plugins/imagenet_txt_format.py index 3a1578431d..d50148956f 100644 --- a/datumaro/plugins/imagenet_txt_format.py +++ b/datumaro/plugins/imagenet_txt_format.py @@ -10,7 +10,6 @@ LabelCategories, AnnotationType, SourceExtractor, Importer ) from datumaro.components.converter import Converter -from datumaro.util.image import find_images class ImagenetTxtPath: @@ -49,26 +48,22 @@ def _load_categories(self, labels): def _load_items(self, path): items = {} - image_dir = self.image_dir - if osp.isdir(image_dir): - images = { osp.splitext(osp.relpath(p, image_dir))[0]: p - for p in find_images(image_dir, recursive=True) } - else: - images = {} - with open(path, encoding='utf-8') as f: for line in f: item = line.split('\"') if 1 < len(item): if len(item) == 3: item_id = item[1] - label_ids = [int(id) for id in item[2].split()] + item = item[2].split() + image = item_id + item[0] + label_ids = [int(id) for id in item[1:]] else: raise Exception("Line %s: unexpected number " "of quotes in filename" % line) else: item = line.split() - item_id = item[0] + item_id = osp.splitext(item[0])[0] + image = item[0] label_ids = [int(id) for id in item[1:]] anno = [] @@ -79,7 +74,7 @@ def _load_items(self, path): anno.append(Label(label)) items[item_id] = DatasetItem(id=item_id, subset=self._subset, - image=images.get(item_id), annotations=anno) + image=osp.join(self.image_dir, image), annotations=anno) return items @@ -105,7 +100,11 @@ def apply(self): labels = {} for item in subset: - labels[item.id] = set(p.label for p in item.annotations + item_id = item.id + if 1 < len(item_id.split()): + item_id = '\"' + item_id + '\"' + item_id += self._find_image_ext(item) + labels[item_id] = set(p.label for p in item.annotations if p.type == AnnotationType.label) if self._save_images and item.has_image: @@ -113,10 +112,8 @@ def apply(self): annotation = '' for item_id, item_labels in labels.items(): - if 1 < len(item_id.split()): - item_id = '\"' + item_id + '\"' - annotation += '%s %s\n' % ( - item_id, ' '.join(str(l) for l in item_labels)) + annotation += '%s %s\n' % (item_id, + ' '.join(str(l) for l in item_labels)) with open(annotation_file, 'w', encoding='utf-8') as f: f.write(annotation) diff --git a/tests/assets/imagenet_txt_dataset/train.txt b/tests/assets/imagenet_txt_dataset/train.txt index 624d111346..e7b972634f 100644 --- a/tests/assets/imagenet_txt_dataset/train.txt +++ b/tests/assets/imagenet_txt_dataset/train.txt @@ -1,4 +1,4 @@ -1 0 -2 5 -3 3 -4 5 \ No newline at end of file +1.jpg 0 +2.jpg 5 +3.jpg 3 +4.jpg 5