Skip to content

Commit

Permalink
[Datumaro] Introduce image info (#1140)
Browse files Browse the repository at this point in the history
* Employ transforms and item wrapper

* Add image class and tests

* Add image info support to formats

* Fix cli

* Fix merge and voc converte

* Update remote images extractor

* Codacy

* Remove item name, require path in Image

* Merge images of dataset items

* Update tests

* Add image dir converter

* Update Datumaro format

* Update COCO format with image info

* Update CVAT format with image info

* Update TFrecord format with image info

* Update VOC formar with image info

* Update YOLO format with image info

* Update dataset manager bindings with image info

* Add image name to id transform

* Fix coco export
  • Loading branch information
zhiltsov-max authored Feb 20, 2020
1 parent 0db48af commit a376ee7
Show file tree
Hide file tree
Showing 36 changed files with 848 additions and 487 deletions.
167 changes: 100 additions & 67 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from cvat.apps.engine.models import Task, ShapeType, AttributeType

import datumaro.components.extractor as datumaro
from datumaro.util.image import lazy_image
from datumaro.util.image import Image


class CvatImagesDirExtractor(datumaro.Extractor):
Expand All @@ -29,8 +29,7 @@ def __init__(self, url):
path = osp.join(dirpath, name)
if self._is_image(path):
item_id = Task.get_image_frame(path)
item = datumaro.DatasetItem(
id=item_id, image=lazy_image(path))
item = datumaro.DatasetItem(id=item_id, image=path)
items.append((item.id, item))

items = sorted(items, key=lambda e: int(e[0]))
Expand All @@ -49,112 +48,90 @@ def __len__(self):
def subsets(self):
return self._subsets

def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()
return self._items[item_id]

def _is_image(self, path):
for ext in self._SUPPORTED_FORMATS:
if osp.isfile(path) and path.endswith(ext):
return True
return False


class CvatTaskExtractor(datumaro.Extractor):
def __init__(self, url, db_task, user):
self._db_task = db_task
self._categories = self._load_categories()

cvat_annotations = TaskAnnotation(db_task.id, user)
with transaction.atomic():
cvat_annotations.init_from_db()
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
class CvatAnnotationsExtractor(datumaro.Extractor):
def __init__(self, url, cvat_annotations):
self._categories = self._load_categories(cvat_annotations)

dm_annotations = []

for cvat_anno in cvat_annotations.group_by_frame():
dm_anno = self._read_cvat_anno(cvat_anno)
dm_item = datumaro.DatasetItem(
id=cvat_anno.frame, annotations=dm_anno)
for cvat_frame_anno in cvat_annotations.group_by_frame():
dm_anno = self._read_cvat_anno(cvat_frame_anno, cvat_annotations)
dm_image = Image(path=cvat_frame_anno.name, size=(
cvat_frame_anno.height, cvat_frame_anno.width)
)
dm_item = datumaro.DatasetItem(id=cvat_frame_anno.frame,
annotations=dm_anno, image=dm_image)
dm_annotations.append((dm_item.id, dm_item))

dm_annotations = sorted(dm_annotations, key=lambda e: int(e[0]))
self._items = OrderedDict(dm_annotations)

self._subsets = None

def __iter__(self):
for item in self._items.values():
yield item

def __len__(self):
return len(self._items)

# pylint: disable=no-self-use
def subsets(self):
return self._subsets
return []
# pylint: enable=no-self-use

def get(self, item_id, subset=None, path=None):
if path or subset:
raise KeyError()
return self._items[item_id]
def categories(self):
return self._categories

def _load_categories(self):
@staticmethod
def _load_categories(cvat_anno):
categories = {}
label_categories = datumaro.LabelCategories()

db_labels = self._db_task.label_set.all()
for db_label in db_labels:
db_attributes = db_label.attributespec_set.all()
label_categories.add(db_label.name)

for db_attr in db_attributes:
label_categories.attributes.add(db_attr.name)
for _, label in cvat_anno.meta['task']['labels']:
label_categories.add(label['name'])
for _, attr in label['attributes']:
label_categories.attributes.add(attr['name'])

categories[datumaro.AnnotationType.label] = label_categories

return categories

def categories(self):
return self._categories

def _read_cvat_anno(self, cvat_anno):
def _read_cvat_anno(self, cvat_frame_anno, cvat_task_anno):
item_anno = []

categories = self.categories()
label_cat = categories[datumaro.AnnotationType.label]

label_map = {}
label_attrs = {}
db_labels = self._db_task.label_set.all()
for db_label in db_labels:
label_map[db_label.name] = label_cat.find(db_label.name)[0]

attrs = {}
db_attributes = db_label.attributespec_set.all()
for db_attr in db_attributes:
attrs[db_attr.name] = db_attr
label_attrs[db_label.name] = attrs
map_label = lambda label_db_name: label_map[label_db_name]
map_label = lambda name: label_cat.find(name)[0]
label_attrs = {
label['name']: label['attributes']
for _, label in cvat_task_anno.meta['task']['labels']
}

def convert_attrs(label, cvat_attrs):
cvat_attrs = {a.name: a.value for a in cvat_attrs}
dm_attr = dict()
for attr_name, attr_spec in label_attrs[label].items():
attr_value = cvat_attrs.get(attr_name, attr_spec.default_value)
for _, a_desc in label_attrs[label]:
a_name = a_desc['name']
a_value = cvat_attrs.get(a_name, a_desc['default_value'])
try:
if attr_spec.input_type == AttributeType.NUMBER:
attr_value = float(attr_value)
elif attr_spec.input_type == AttributeType.CHECKBOX:
attr_value = attr_value.lower() == 'true'
dm_attr[attr_name] = attr_value
if a_desc['input_type'] == AttributeType.NUMBER:
a_value = float(a_value)
elif a_desc['input_type'] == AttributeType.CHECKBOX:
a_value = (a_value.lower() == 'true')
dm_attr[a_name] = a_value
except Exception as e:
slogger.task[self._db_task.id].error(
"Failed to convert attribute '%s'='%s': %s" % \
(attr_name, attr_value, e))
raise Exception(
"Failed to convert attribute '%s'='%s': %s" %
(a_name, a_value, e))
return dm_attr

for tag_obj in cvat_anno.tags:
for tag_obj in cvat_frame_anno.tags:
anno_group = tag_obj.group
anno_label = map_label(tag_obj.label)
anno_attr = convert_attrs(tag_obj.label, tag_obj.attributes)
Expand All @@ -163,7 +140,7 @@ def convert_attrs(label, cvat_attrs):
attributes=anno_attr, group=anno_group)
item_anno.append(anno)

for shape_obj in cvat_anno.labeled_shapes:
for shape_obj in cvat_frame_anno.labeled_shapes:
anno_group = shape_obj.group
anno_label = map_label(shape_obj.label)
anno_attr = convert_attrs(shape_obj.label, shape_obj.attributes)
Expand All @@ -183,8 +160,64 @@ def convert_attrs(label, cvat_attrs):
anno = datumaro.Bbox(x0, y0, x1 - x0, y1 - y0,
label=anno_label, attributes=anno_attr, group=anno_group)
else:
raise Exception("Unknown shape type '%s'" % (shape_obj.type))
raise Exception("Unknown shape type '%s'" % shape_obj.type)

item_anno.append(anno)

return item_anno
return item_anno


class CvatTaskExtractor(CvatAnnotationsExtractor):
def __init__(self, url, db_task, user):
cvat_annotations = TaskAnnotation(db_task.id, user)
with transaction.atomic():
cvat_annotations.init_from_db()
cvat_annotations = Annotation(cvat_annotations.ir_data, db_task)
super().__init__(url, cvat_annotations)


def match_frame(item, cvat_task_anno):
frame_number = None
if frame_number is None:
try:
frame_number = cvat_task_anno.match_frame(item.id)
except Exception:
pass
if frame_number is None and item.has_image:
try:
frame_number = cvat_task_anno.match_frame(item.image.filename)
except Exception:
pass
if frame_number is None:
try:
frame_number = int(item.id)
except Exception:
pass
if not frame_number in cvat_task_anno.frame_info:
raise Exception("Could not match item id: '%s' with any task frame" %
item.id)
return frame_number

def import_dm_annotations(dm_dataset, cvat_task_anno):
shapes = {
datumaro.AnnotationType.bbox: ShapeType.RECTANGLE,
datumaro.AnnotationType.polygon: ShapeType.POLYGON,
datumaro.AnnotationType.polyline: ShapeType.POLYLINE,
datumaro.AnnotationType.points: ShapeType.POINTS,
}

label_cat = dm_dataset.categories()[datumaro.AnnotationType.label]

for item in dm_dataset:
frame_number = match_frame(item, cvat_task_anno)

for ann in item.annotations:
if ann.type in shapes:
cvat_task_anno.add_shape(cvat_task_anno.LabeledShape(
type=shapes[ann.type],
frame=frame_number,
label=label_cat.items[ann.label].name,
points=ann.points,
occluded=False,
attributes=[],
))
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SchemaBuilder as _SchemaBuilder,
)
import datumaro.components.extractor as datumaro
from datumaro.util.image import lazy_image, load_image
from datumaro.util.image import lazy_image, load_image, Image

from cvat.utils.cli.core import CLI as CVAT_CLI, CVAT_API_V1

Expand Down Expand Up @@ -103,8 +103,11 @@ def __init__(self, url):
items = []
for entry in image_list:
item_id = entry['id']
item = datumaro.DatasetItem(
id=item_id, image=self._make_image_loader(item_id))
size = None
if entry.get('height') and entry.get('width'):
size = (entry['height'], entry['width'])
image = Image(data=self._make_image_loader(item_id), size=size)
item = datumaro.DatasetItem(id=item_id, image=image)
items.append((item.id, item))

items = sorted(items, key=lambda e: int(e[0]))
Expand Down
21 changes: 12 additions & 9 deletions datumaro/datumaro/cli/contexts/project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,17 @@ def import_command(args):
if project_name is None:
project_name = osp.basename(project_dir)

extra_args = {}
try:
env = Environment()
importer = env.make_importer(args.format)
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)
except KeyError:
raise CliException("Importer for format '%s' is not found" % \
args.format)

extra_args = {}
if hasattr(importer, 'from_cmdline'):
extra_args = importer.from_cmdline(args.extra_args)

log.info("Importing project from '%s' as '%s'" % \
(args.source, args.format))

Expand Down Expand Up @@ -293,13 +294,14 @@ def export_command(args):

try:
converter = project.env.converters.get(args.format)
if hasattr(converter, 'from_cmdline'):
extra_args = converter.from_cmdline(args.extra_args)
converter = converter(**extra_args)
except KeyError:
raise CliException("Converter for format '%s' is not found" % \
args.format)

if hasattr(converter, 'from_cmdline'):
extra_args = converter.from_cmdline(args.extra_args)
converter = converter(**extra_args)

filter_args = FilterModes.make_filter_args(args.filter_mode)

log.info("Loading the project...")
Expand Down Expand Up @@ -559,14 +561,15 @@ def transform_command(args):
(project.config.project_name, make_file_name(args.transform)))
dst_dir = osp.abspath(dst_dir)

extra_args = {}
try:
transform = project.env.transforms.get(args.transform)
if hasattr(transform, 'from_cmdline'):
extra_args = transform.from_cmdline(args.extra_args)
except KeyError:
raise CliException("Transform '%s' is not found" % args.transform)

extra_args = {}
if hasattr(transform, 'from_cmdline'):
extra_args = transform.from_cmdline(args.extra_args)

log.info("Loading the project...")
dataset = project.make_dataset()

Expand Down
Loading

0 comments on commit a376ee7

Please sign in to comment.