Skip to content

Commit

Permalink
[Datumaro] Add merge command with segment intersection (#1695)
Browse files Browse the repository at this point in the history
* Add multi source merge

* update changelog

* cli update

* linter

* fixes and tests

* fix test

* fix test

* relax type requirements in annotations

* fix polylines

* Make groups more stable

* Add group checks

* add group check test
  • Loading branch information
zhiltsov-max authored Aug 17, 2020
1 parent 90cc36e commit 17a5554
Show file tree
Hide file tree
Showing 20 changed files with 1,438 additions and 134 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Siammask tracker as DL serverless function (<https://github.com/opencv/cvat/pull/1988>)
- [Datumaro] Added model info and source info commands (<https://github.com/opencv/cvat/pull/1973>)
- [Datumaro] Dataset statistics (<https://github.com/opencv/cvat/pull/1668>)
- [Datumaro] Multi-dataset merge (https://github.com/opencv/cvat/pull/1695)

### Changed
- Shape coordinates are rounded to 2 digits in dumped annotations (<https://github.com/opencv/cvat/pull/1970>)
Expand Down
1 change: 1 addition & 0 deletions datumaro/datumaro/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def make_parser():
('remove', commands.remove, "Remove source from project"),
('export', commands.export, "Export project"),
('explain', commands.explain, "Run Explainable AI algorithm for model"),
('merge', commands.merge, "Merge datasets"),
('convert', commands.convert, "Convert dataset"),
]

Expand Down
2 changes: 1 addition & 1 deletion datumaro/datumaro/cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
#
# SPDX-License-Identifier: MIT

from . import add, create, explain, export, remove, convert
from . import add, create, explain, export, remove, merge, convert
124 changes: 124 additions & 0 deletions datumaro/datumaro/cli/commands/merge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@

# Copyright (C) 2020 Intel Corporation
#
# SPDX-License-Identifier: MIT

import argparse
import json
import logging as log
import os.path as osp
from collections import OrderedDict

from datumaro.components.project import Project
from datumaro.components.operations import (IntersectMerge,
QualityError, MergeError)

from ..util import at_least, MultilineFormatter, CliException
from ..util.project import generate_next_file_name, load_project


def build_parser(parser_ctor=argparse.ArgumentParser):
parser = parser_ctor(help="Merge few projects",
description="""
Merges multiple datasets into one. This can be useful if you
have few annotations and wish to merge them,
taking into consideration potential overlaps and conflicts.
This command can try to find a common ground by voting or
return a list of conflicts.|n
|n
Examples:|n
- Merge annotations from 3 (or more) annotators:|n
|s|smerge project1/ project2/ project3/|n
- Check groups of the merged dataset for consistence:|n
|s|s|slook for groups consising of 'person', 'hand' 'head', 'foot'|n
|s|smerge project1/ project2/ -g 'person,hand?,head,foot?'
""",
formatter_class=MultilineFormatter)

def _group(s):
return s.split(',')

parser.add_argument('project', nargs='+', action=at_least(2),
help="Path to a project (repeatable)")
parser.add_argument('-iou', '--iou-thresh', default=0.25, type=float,
help="IoU match threshold for segments (default: %(default)s)")
parser.add_argument('-oconf', '--output-conf-thresh',
default=0.0, type=float,
help="Confidence threshold for output "
"annotations (default: %(default)s)")
parser.add_argument('--quorum', default=0, type=int,
help="Minimum count for a label and attribute voting "
"results to be counted (default: %(default)s)")
parser.add_argument('-g', '--groups', action='append', type=_group,
default=[],
help="A comma-separated list of labels in "
"annotation groups to check. '?' postfix can be added to a label to"
"make it optional in the group (repeatable)")
parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None,
help="Output directory (default: current project's dir)")
parser.add_argument('--overwrite', action='store_true',
help="Overwrite existing files in the save directory")
parser.set_defaults(command=merge_command)

return parser

def merge_command(args):
source_projects = [load_project(p) for p in args.project]

dst_dir = args.dst_dir
if dst_dir:
if not args.overwrite and osp.isdir(dst_dir) and os.listdir(dst_dir):
raise CliException("Directory '%s' already exists "
"(pass --overwrite to overwrite)" % dst_dir)
else:
dst_dir = generate_next_file_name('merged')

source_datasets = []
for p in source_projects:
log.debug("Loading project '%s' dataset", p.config.project_name)
source_datasets.append(p.make_dataset())

merger = IntersectMerge(conf=IntersectMerge.Conf(
pairwise_dist=args.iou_thresh, groups=args.groups,
output_conf_thresh=args.output_conf_thresh, quorum=args.quorum
))
merged_dataset = merger(source_datasets)

merged_project = Project()
output_dataset = merged_project.make_dataset()
output_dataset.define_categories(merged_dataset.categories())
merged_dataset = output_dataset.update(merged_dataset)
merged_dataset.save(save_dir=dst_dir)

report_path = osp.join(dst_dir, 'merge_report.json')
save_merge_report(merger, report_path)

dst_dir = osp.abspath(dst_dir)
log.info("Merge results have been saved to '%s'" % dst_dir)
log.info("Report has been saved to '%s'" % report_path)

return 0

def save_merge_report(merger, path):
item_errors = OrderedDict()
source_errors = OrderedDict()
all_errors = []

for e in merger.errors:
if isinstance(e, QualityError):
item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1
elif isinstance(e, MergeError):
for s in e.sources:
source_errors[s] = source_errors.get(s, 0) + 1
item_errors[str(e.item_id)] = item_errors.get(str(e.item_id), 0) + 1

all_errors.append(str(e))

errors = OrderedDict([
('Item errors', item_errors),
('Source errors', source_errors),
('All errors', all_errors),
])

with open(path, 'w') as f:
json.dump(errors, f, indent=4)
22 changes: 22 additions & 0 deletions datumaro/datumaro/cli/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ def _fill_text(self, text, width, indent):
multiline_text += formatted_paragraph
return multiline_text

def required_count(nmin=0, nmax=0):
assert 0 <= nmin and 0 <= nmax and nmin or nmax

class RequiredCount(argparse.Action):
def __call__(self, parser, args, values, option_string=None):
k = len(values)
if not ((nmin and (nmin <= k) or not nmin) and \
(nmax and (k <= nmax) or not nmax)):
msg = "Argument '%s' requires" % self.dest
if nmin and nmax:
msg += " from %s to %s arguments" % (nmin, nmax)
elif nmin:
msg += " at least %s arguments" % nmin
else:
msg += " no more %s arguments" % nmax
raise argparse.ArgumentTypeError(msg)
setattr(args, self.dest, values)
return RequiredCount

def at_least(n):
return required_count(n, 0)

def make_file_name(s):
# adapted from
# https://docs.djangoproject.com/en/2.1/_modules/django/utils/text/#slugify
Expand Down
25 changes: 4 additions & 21 deletions datumaro/datumaro/components/algorithms/rise.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from math import ceil

from datumaro.components.extractor import AnnotationType
from datumaro.util.annotation_util import nms


def flatmatvec(mat):
Expand Down Expand Up @@ -51,24 +52,6 @@ def split_outputs(annotations):
bboxes.append(r)
return labels, bboxes

@staticmethod
def nms(boxes, iou_thresh=0.5):
indices = np.argsort([b.attributes['score'] for b in boxes])
ious = np.array([[a.iou(b) for b in boxes] for a in boxes])

predictions = []
while len(indices) != 0:
i = len(indices) - 1
pred_idx = indices[i]
to_remove = [i]
predictions.append(boxes[pred_idx])
for i, box_idx in enumerate(indices[:i]):
if iou_thresh < ious[pred_idx, box_idx]:
to_remove.append(i)
indices = np.delete(indices, to_remove)

return predictions

def normalize_hmaps(self, heatmaps, counts):
eps = np.finfo(heatmaps.dtype).eps
mhmaps = flatmatvec(heatmaps)
Expand Down Expand Up @@ -106,7 +89,7 @@ def apply(self, image, progressive=False):
result_bboxes = [b for b in result_bboxes \
if self.det_conf_thresh <= b.attributes['score']]
if 0 < self.nms_thresh:
result_bboxes = self.nms(result_bboxes, self.nms_thresh)
result_bboxes = nms(result_bboxes, self.nms_thresh)

predicted_labels = set()
if len(result_labels) != 0:
Expand Down Expand Up @@ -194,15 +177,15 @@ def apply(self, image, progressive=False):
result_bboxes = [b for b in result_bboxes \
if self.det_conf_thresh <= b.attributes['score']]
if 0 < self.nms_thresh:
result_bboxes = self.nms(result_bboxes, self.nms_thresh)
result_bboxes = nms(result_bboxes, self.nms_thresh)

for detection in result_bboxes:
for pred_idx, pred in enumerate(predicted_bboxes):
if pred.label != detection.label:
continue

iou = pred.iou(detection)
assert 0 <= iou and iou <= 1
assert iou == -1 or 0 <= iou and iou <= 1
if iou < iou_thresh:
continue

Expand Down
36 changes: 10 additions & 26 deletions datumaro/datumaro/components/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datumaro.util.image import Image
from datumaro.util.attrs_util import not_empty, default_if_none


AnnotationType = Enum('AnnotationType',
[
'label',
Expand All @@ -28,9 +29,9 @@

@attrs
class Annotation:
id = attrib(converter=int, default=0, kw_only=True)
attributes = attrib(converter=dict, factory=dict, kw_only=True)
group = attrib(converter=int, default=0, kw_only=True)
id = attrib(default=0, validator=default_if_none(int), kw_only=True)
attributes = attrib(factory=dict, validator=default_if_none(dict), kw_only=True)
group = attrib(default=0, validator=default_if_none(int), kw_only=True)

def __attrs_post_init__(self):
assert isinstance(self.type, AnnotationType)
Expand Down Expand Up @@ -92,7 +93,7 @@ def _reindex(self):
self._indices = indices

def add(self, name, parent=None, attributes=None):
assert name not in self._indices
assert name not in self._indices, name
if attributes is None:
attributes = set()
else:
Expand All @@ -110,7 +111,7 @@ def add(self, name, parent=None, attributes=None):

def find(self, name):
index = self._indices.get(name)
if index:
if index is not None:
return index, self.items[index]
return index, None

Expand Down Expand Up @@ -148,7 +149,7 @@ class Mask(Annotation):
_image = attrib()
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
z_order = attrib(default=0, validator=default_if_none(int), kw_only=True)

@property
def image(self):
Expand Down Expand Up @@ -274,31 +275,13 @@ def extract(self, instance_id):
def lazy_extract(self, instance_id):
return lambda: self.extract(instance_id)

def compute_iou(bbox_a, bbox_b):
aX, aY, aW, aH = bbox_a
bX, bY, bW, bH = bbox_b
in_right = min(aX + aW, bX + bW)
in_left = max(aX, bX)
in_top = max(aY, bY)
in_bottom = min(aY + aH, bY + bH)

in_w = max(0, in_right - in_left)
in_h = max(0, in_bottom - in_top)
intersection = in_w * in_h

a_area = aW * aH
b_area = bW * bH
union = a_area + b_area - intersection

return intersection / max(1.0, union)

@attrs
class _Shape(Annotation):
points = attrib(converter=lambda x:
[round(p, _COORDINATE_ROUNDING_DIGITS) for p in x])
label = attrib(converter=attr.converters.optional(int),
default=None, kw_only=True)
z_order = attrib(converter=int, default=0, kw_only=True)
z_order = attrib(default=0, validator=default_if_none(int), kw_only=True)

def get_area(self):
raise NotImplementedError()
Expand Down Expand Up @@ -386,7 +369,8 @@ def as_polygon(self):
]

def iou(self, other):
return compute_iou(self.get_bbox(), other.get_bbox())
from datumaro.util.annotation_util import bbox_iou
return bbox_iou(self.get_bbox(), other.get_bbox())

def wrap(item, **kwargs):
d = {'x': item.x, 'y': item.y, 'w': item.w, 'h': item.h}
Expand Down
Loading

0 comments on commit 17a5554

Please sign in to comment.