Add VGGFace2 format support (cvat-ai#69)
* Add VGGFace2 format support
Co-authored-by: Maxim Zhiltsov <maxim.zhiltsov@intel.com>
1 parent 77fdd4d · commit 893dd96
Showing 9 changed files with 249 additions and 0 deletions.
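For context, the tests added in this commit exercise the new plugin roughly as follows. This is a minimal usage sketch, not part of the diff: the paths are placeholders, and it assumes the plugin is registered under the name 'vgg_face2', as the importer's find_sources and the tests below suggest.

# Minimal usage sketch based on the tests in this commit; paths are placeholders.
from datumaro.components.project import Project
from datumaro.plugins.vgg_face2_format import VggFace2Converter

# Import an on-disk VGG Face2 layout as a Datumaro dataset.
dataset = Project.import_from('path/to/vgg_face2_dataset', 'vgg_face2') \
    .make_dataset()

# Export a dataset back to the VGG Face2 layout, copying images as well.
VggFace2Converter.convert(dataset, 'path/to/output_dir', save_images=True)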
@@ -0,0 +1,139 @@
# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import csv
import os
import os.path as osp
from glob import glob

from datumaro.components.converter import Converter
from datumaro.components.extractor import (AnnotationType, Bbox, DatasetItem,
    Importer, Points, LabelCategories, SourceExtractor)


class VggFace2Path:
    ANNOTATION_DIR = "bb_landmark"
    IMAGE_EXT = '.jpg'
    BBOXES_FILE = 'loose_bb_'
    LANDMARKS_FILE = 'loose_landmark_'

class VggFace2Extractor(SourceExtractor):
    def __init__(self, path):
        if not osp.isfile(path):
            raise Exception("Can't read .csv annotation file '%s'" % path)
        self._path = path
        self._dataset_dir = osp.dirname(osp.dirname(path))

        subset = osp.splitext(osp.basename(path))[0]
        if subset.startswith(VggFace2Path.LANDMARKS_FILE):
            subset = subset.split('_')[2]
        super().__init__(subset=subset)

        self._load_categories()
        self._items = list(self._load_items(path).values())

    def _load_categories(self):
        self._categories[AnnotationType.label] = LabelCategories()

    def _load_items(self, path):
        items = {}
        with open(path) as content:
            landmarks_table = list(csv.DictReader(content))

        for row in landmarks_table:
            item_id = row['NAME_ID']
            image_path = osp.join(self._dataset_dir, self._subset,
                item_id + VggFace2Path.IMAGE_EXT)
            annotations = []
            if len([p for p in row if row[p] == '']) == 0 and len(row) == 11:
                annotations.append(Points(
                    [float(row[p]) for p in row if p != 'NAME_ID']))
            if item_id in items and 0 < len(annotations):
                annotation = items[item_id].annotations
                annotation.append(annotations[0])
            else:
                items[item_id] = DatasetItem(id=item_id, subset=self._subset,
                    image=image_path, annotations=annotations)

        bboxes_path = osp.join(self._dataset_dir, VggFace2Path.ANNOTATION_DIR,
            VggFace2Path.BBOXES_FILE + self._subset + '.csv')
        if osp.isfile(bboxes_path):
            with open(bboxes_path) as content:
                bboxes_table = list(csv.DictReader(content))
            for row in bboxes_table:
                if len([p for p in row if row[p] == '']) == 0 and len(row) == 5:
                    item_id = row['NAME_ID']
                    annotations = items[item_id].annotations
                    annotations.append(Bbox(int(row['X']), int(row['Y']),
                        int(row['W']), int(row['H'])))
        return items

class VggFace2Importer(Importer):
    @classmethod
    def find_sources(cls, path):
        subset_paths = [p for p in glob(osp.join(path,
                VggFace2Path.ANNOTATION_DIR, '**.csv'), recursive=True)
            if not osp.basename(p).startswith(VggFace2Path.BBOXES_FILE)]
        sources = []
        for subset_path in subset_paths:
            sources += cls._find_sources_recursive(
                subset_path, '.csv', 'vgg_face2')
        return sources

class VggFace2Converter(Converter):
    DEFAULT_IMAGE_EXT = '.jpg'

    def apply(self):
        save_dir = self._save_dir

        os.makedirs(save_dir, exist_ok=True)
        for subset_name, subset in self._extractor.subsets().items():
            subset_dir = osp.join(save_dir, subset_name)
            bboxes_table = []
            landmarks_table = []
            for item in subset:
                if item.has_image and self._save_images:
                    self._save_image(item, osp.join(save_dir, subset_dir,
                        item.id + VggFace2Path.IMAGE_EXT))

                landmarks = [a for a in item.annotations
                    if a.type == AnnotationType.points]
                if landmarks:
                    for landmark in landmarks:
                        points = landmark.points
                        landmarks_table.append({'NAME_ID': item.id,
                            'P1X': points[0], 'P1Y': points[1],
                            'P2X': points[2], 'P2Y': points[3],
                            'P3X': points[4], 'P3Y': points[5],
                            'P4X': points[6], 'P4Y': points[7],
                            'P5X': points[8], 'P5Y': points[9]})
                else:
                    landmarks_table.append({'NAME_ID': item.id})

                bboxes = [a for a in item.annotations
                    if a.type == AnnotationType.bbox]
                if bboxes:
                    for bbox in bboxes:
                        bboxes_table.append({'NAME_ID': item.id, 'X': int(bbox.x),
                            'Y': int(bbox.y), 'W': int(bbox.w), 'H': int(bbox.h)})

            landmarks_path = osp.join(save_dir, VggFace2Path.ANNOTATION_DIR,
                VggFace2Path.LANDMARKS_FILE + subset_name + '.csv')
            os.makedirs(osp.dirname(landmarks_path), exist_ok=True)
            with open(landmarks_path, 'w', newline='') as file:
                columns = ['NAME_ID', 'P1X', 'P1Y', 'P2X', 'P2Y',
                    'P3X', 'P3Y', 'P4X', 'P4Y', 'P5X', 'P5Y']
                writer = csv.DictWriter(file, fieldnames=columns)
                writer.writeheader()
                writer.writerows(landmarks_table)

            if bboxes_table:
                bboxes_path = osp.join(save_dir, VggFace2Path.ANNOTATION_DIR,
                    VggFace2Path.BBOXES_FILE + subset_name + '.csv')
                os.makedirs(osp.dirname(bboxes_path), exist_ok=True)
                with open(bboxes_path, 'w', newline='') as file:
                    columns = ['NAME_ID', 'X', 'Y', 'W', 'H']
                    writer = csv.DictWriter(file, fieldnames=columns)
                    writer.writeheader()
                    writer.writerows(bboxes_table)
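For orientation, the directory layout read by VggFace2Extractor and written by VggFace2Converter above (per VggFace2Path) would look roughly like this; <subset> is a placeholder, and the sample item id comes from the test assets below:

<dataset_dir>/
    bb_landmark/
        loose_landmark_<subset>.csv   # NAME_ID,P1X,P1Y, ... ,P5X,P5Y
        loose_bb_<subset>.csv         # NAME_ID,X,Y,W,H
    <subset>/
        n000001/0001_01.jpg           # one image per NAME_ID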
3 changes: 3 additions & 0 deletions
tests/assets/vgg_face2_dataset/bb_landmark/loose_bb_train.csv
@@ -0,0 +1,3 @@
NAME_ID,X,Y,W,H
n000001/0001_01,2,2,1,2
n000002/0002_01,1,3,1,1
3 changes: 3 additions & 0 deletions
tests/assets/vgg_face2_dataset/bb_landmark/loose_landmark_train.csv
@@ -0,0 +1,3 @@
NAME_ID,P1X,P1Y,P2X,P2Y,P3X,P3Y,P4X,P4Y,P5X,P5Y
n000001/0001_01,2.787,2.898,2.965,2.79,2.8,2.456,2.81,2.32,2.89,2.3
n000002/0002_01,1.2,3.8,1.8,3.82,1.51,3.634,1.43,3.34,1.65,3.32
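As a worked example (illustration only, not part of the commit), this is roughly how VggFace2Extractor._load_items above turns the first row of this file into a Points annotation:

# Illustration only: mirrors VggFace2Extractor._load_items for a single row.
import csv

with open('bb_landmark/loose_landmark_train.csv') as f:
    row = next(csv.DictReader(f))

item_id = row['NAME_ID']    # 'n000001/0001_01'
# The 10 non-NAME_ID columns become one 5-point (x, y) landmark annotation.
points = [float(row[p]) for p in row if p != 'NAME_ID']
# -> [2.787, 2.898, 2.965, 2.79, 2.8, 2.456, 2.81, 2.32, 2.89, 2.3]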
(Two additional changed files could not be displayed.)
@@ -0,0 +1,99 @@
import os.path as osp
from unittest import TestCase

import numpy as np
from datumaro.components.extractor import Bbox, DatasetItem, Points
from datumaro.components.project import Dataset, Project
from datumaro.plugins.vgg_face2_format import (VggFace2Converter,
    VggFace2Importer)
from datumaro.util.test_utils import TestDir, compare_datasets


class VggFace2FormatTest(TestCase):
    def test_can_save_and_load(self):
        source_dataset = Dataset.from_iterable([
            DatasetItem(id='1', subset='train', image=np.ones((8, 8, 3)),
                annotations=[
                    Bbox(0, 2, 4, 2),
                    Points([3.2, 3.12, 4.11, 3.2, 2.11,
                        2.5, 3.5, 2.11, 3.8, 2.13]),
                ]
            ),
            DatasetItem(id='2', subset='train', image=np.ones((10, 10, 3)),
                annotations=[
                    Points([4.23, 4.32, 5.34, 4.45, 3.54,
                        3.56, 4.52, 3.51, 4.78, 3.34]),
                ]
            ),
            DatasetItem(id='3', subset='val', image=np.ones((8, 8, 3))),
            DatasetItem(id='4', subset='val', image=np.ones((10, 10, 3)),
                annotations=[
                    Bbox(0, 2, 4, 2),
                    Points([3.2, 3.12, 4.11, 3.2, 2.11,
                        2.5, 3.5, 2.11, 3.8, 2.13]),
                    Bbox(2, 2, 1, 2),
                    Points([2.787, 2.898, 2.965, 2.79, 2.8,
                        2.456, 2.81, 2.32, 2.89, 2.3]),
                ]
            ),
            DatasetItem(id='5', subset='val', image=np.ones((8, 8, 3)),
                annotations=[
                    Bbox(2, 2, 2, 2),
                ]
            ),
        ], categories=[])

        with TestDir() as test_dir:
            VggFace2Converter.convert(source_dataset, test_dir, save_images=True)
            parsed_dataset = VggFace2Importer()(test_dir).make_dataset()

            compare_datasets(self, source_dataset, parsed_dataset)

    def test_can_save_dataset_with_no_subsets(self):
        source_dataset = Dataset.from_iterable([
            DatasetItem(id='a/b/1', image=np.ones((8, 8, 3)),
                annotations=[
                    Bbox(0, 2, 4, 2),
                    Points([4.23, 4.32, 5.34, 4.45, 3.54,
                        3.56, 4.52, 3.51, 4.78, 3.34]),
                ]
            ),
        ], categories=[])

        with TestDir() as test_dir:
            VggFace2Converter.convert(source_dataset, test_dir, save_images=True)
            parsed_dataset = VggFace2Importer()(test_dir).make_dataset()

            compare_datasets(self, source_dataset, parsed_dataset)


DUMMY_DATASET_DIR = osp.join(osp.dirname(__file__), 'assets', 'vgg_face2_dataset')

class VggFace2ImporterTest(TestCase):
    def test_can_detect(self):
        self.assertTrue(VggFace2Importer.detect(DUMMY_DATASET_DIR))

    def test_can_import(self):
        expected_dataset = Dataset.from_iterable([
            DatasetItem(id='n000001/0001_01', subset='train',
                image=np.ones((10, 15, 3)),
                annotations=[
                    Bbox(2, 2, 1, 2),
                    Points([2.787, 2.898, 2.965, 2.79, 2.8,
                        2.456, 2.81, 2.32, 2.89, 2.3]),
                ]
            ),
            DatasetItem(id='n000002/0002_01', subset='train',
                image=np.ones((10, 15, 3)),
                annotations=[
                    Bbox(1, 3, 1, 1),
                    Points([1.2, 3.8, 1.8, 3.82, 1.51,
                        3.634, 1.43, 3.34, 1.65, 3.32])
                ]
            ),
        ], categories=[])

        dataset = Project.import_from(DUMMY_DATASET_DIR, 'vgg_face2') \
            .make_dataset()

        compare_datasets(self, expected_dataset, dataset)