From fd8ff3594241691430c1737a38e0a8b30fd3e6fa Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Tue, 18 Aug 2020 17:38:21 +0300 Subject: [PATCH 1/8] Add exact diff command --- .../datumaro/cli/contexts/project/__init__.py | 100 ++++++-- .../datumaro/cli/contexts/project/diff.py | 2 +- datumaro/datumaro/components/comparator.py | 113 --------- datumaro/datumaro/components/extractor.py | 2 +- datumaro/datumaro/components/operations.py | 229 ++++++++++++++++++ datumaro/datumaro/util/__init__.py | 11 + datumaro/datumaro/util/test_utils.py | 3 +- datumaro/tests/test_diff.py | 157 +++++++----- 8 files changed, 421 insertions(+), 196 deletions(-) delete mode 100644 datumaro/datumaro/components/comparator.py diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index e6d5809b541..d28b1bfc41e 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -4,25 +4,26 @@ # SPDX-License-Identifier: MIT import argparse -from enum import Enum import json import logging as log import os import os.path as osp import shutil +from enum import Enum -from datumaro.components.project import Project, Environment, \ - PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG -from datumaro.components.comparator import Comparator +from datumaro.components.cli_plugin import CliPlugin from datumaro.components.dataset_filter import DatasetItemEncoder from datumaro.components.extractor import AnnotationType -from datumaro.components.cli_plugin import CliPlugin -from datumaro.components.operations import \ - compute_image_statistics, compute_ann_statistics +from datumaro.components.operations import (DistanceComparator, + ExactComparator, compute_ann_statistics, compute_image_statistics, mean_std) +from datumaro.components.project import \ + PROJECT_DEFAULT_CONFIG as DEFAULT_CONFIG +from datumaro.components.project import Environment, Project + +from ...util import (CliException, MultilineFormatter, add_subparser, + make_file_name) +from ...util.project import generate_next_file_name, load_project from .diff import DiffVisualizer -from ...util import add_subparser, CliException, MultilineFormatter, \ - make_file_name -from ...util.project import load_project, generate_next_file_name def build_create_parser(parser_ctor=argparse.ArgumentParser): @@ -503,12 +504,12 @@ def merge_command(args): def build_diff_parser(parser_ctor=argparse.ArgumentParser): parser = parser_ctor(help="Compare projects", description=""" - Compares two projects.|n + Compares two projects, match annotations by distance.|n |n Examples:|n - - Compare two projects, consider bboxes matching if their IoU > 0.7,|n + - Compare two projects, match boxes if IoU > 0.7,|n |s|s|s|sprint results to Tensorboard: - |s|sdiff path/to/other/project -o diff/ -f tensorboard --iou-thresh 0.7 + |s|sdiff path/to/other/project -o diff/ -v tensorboard --iou-thresh 0.7 """, formatter_class=MultilineFormatter) @@ -516,7 +517,7 @@ def build_diff_parser(parser_ctor=argparse.ArgumentParser): help="Directory of the second project to be compared") parser.add_argument('-o', '--output-dir', dest='dst_dir', default=None, help="Directory to save comparison results (default: do not save)") - parser.add_argument('-f', '--format', + parser.add_argument('-v', '--visualizer', default=DiffVisualizer.DEFAULT_FORMAT, choices=[f.name for f in DiffVisualizer.Format], help="Output format (default: %(default)s)") @@ -536,9 +537,7 @@ def diff_command(args): first_project = load_project(args.project_dir) second_project = load_project(args.other_project_dir) - comparator = Comparator( - iou_threshold=args.iou_thresh, - conf_threshold=args.conf_thresh) + comparator = DistanceComparator(iou_threshold=args.iou_thresh) dst_dir = args.dst_dir if dst_dir: @@ -556,7 +555,7 @@ def diff_command(args): dst_dir_existed = osp.exists(dst_dir) try: visualizer = DiffVisualizer(save_dir=dst_dir, comparator=comparator, - output_format=args.format) + output_format=args.visualizer) visualizer.save_dataset_diff( first_project.make_dataset(), second_project.make_dataset()) @@ -567,6 +566,70 @@ def diff_command(args): return 0 +def build_ediff_parser(parser_ctor=argparse.ArgumentParser): + parser = parser_ctor(help="Compare projects for equality", + description=""" + Compares two projects for equality.|n + |n + Examples:|n + - Compare two projects, exclude annotation group |n + |s|s|sand the 'is_crowd' attribute from comparison:|n + |s|sediff other/project/ -if group -ia is_crowd + """, + formatter_class=MultilineFormatter) + + parser.add_argument('other_project_dir', + help="Directory of the second project to be compared") + parser.add_argument('-iia', '--ignore-item-attr', action='append', + help="Ignore an item attribute (repeatable)") + parser.add_argument('-ia', '--ignore-attr', action='append', + help="Ignore an annotation attribute (repeatable)") + parser.add_argument('-if', '--ignore-field', + action='append', default=['id', 'group'], + help="Ignore an annotation field (repeatable, default: %(default)s)") + parser.add_argument('--all', action='store_true', + help="Include matches in the output") + parser.add_argument('-p', '--project', dest='project_dir', default='.', + help="Directory of the first project to be compared (default: current dir)") + parser.set_defaults(command=ediff_command) + + return parser + +def ediff_command(args): + first_project = load_project(args.project_dir) + second_project = load_project(args.other_project_dir) + + comparator = ExactComparator( + ignored_fields=args.ignore_field or [], + ignored_attrs=args.ignore_attr or [], + ignored_item_attrs=args.ignore_item_attr or []) + matches, mismatches, a_extra, b_extra, errors = \ + comparator.compare_datasets( + first_project.make_dataset(), second_project.make_dataset()) + output = { + "mismatches": mismatches, + "a_extra_items": sorted(a_extra), + "b_extra_items": sorted(b_extra), + "errors": errors, + } + if args.all: + output["matches"] = matches + + output_file = generate_next_file_name('diff', ext='.json') + with open(output_file, 'w') as f: + json.dump(output, f, indent=4, sort_keys=True) + + print("Found:") + print("The first project has %s unmatched items" % len(a_extra)) + print("The second project has %s unmatched items" % len(b_extra)) + print("%s item conflicts" % len(errors)) + print("%s matching annotations" % len(matches)) + print("%s mismatching annotations" % len(mismatches)) + + log.info("Output has been saved to '%s'" % output_file) + + return 0 + def build_transform_parser(parser_ctor=argparse.ArgumentParser): builtins = sorted(Environment().transforms.items) @@ -753,6 +816,7 @@ def build_parser(parser_ctor=argparse.ArgumentParser): add_subparser(subparsers, 'extract', build_extract_parser) add_subparser(subparsers, 'merge', build_merge_parser) add_subparser(subparsers, 'diff', build_diff_parser) + add_subparser(subparsers, 'ediff', build_ediff_parser) add_subparser(subparsers, 'transform', build_transform_parser) add_subparser(subparsers, 'info', build_info_parser) add_subparser(subparsers, 'stats', build_stats_parser) diff --git a/datumaro/datumaro/cli/contexts/project/diff.py b/datumaro/datumaro/cli/contexts/project/diff.py index 785c6c8ecde..571908f6679 100644 --- a/datumaro/datumaro/cli/contexts/project/diff.py +++ b/datumaro/datumaro/cli/contexts/project/diff.py @@ -217,7 +217,7 @@ def save_item_bbox_diff(self, item_a, item_b, diff): _, mispred, a_unmatched, b_unmatched = diff if 0 < len(a_unmatched) + len(b_unmatched) + len(mispred): - img_a = item_a.image.copy() + img_a = item_a.image.data.copy() img_b = img_a.copy() for a_bbox, b_bbox in mispred: self.draw_bbox(img_a, a_bbox, (0, 255, 0)) diff --git a/datumaro/datumaro/components/comparator.py b/datumaro/datumaro/components/comparator.py deleted file mode 100644 index 842a3963a98..00000000000 --- a/datumaro/datumaro/components/comparator.py +++ /dev/null @@ -1,113 +0,0 @@ - -# Copyright (C) 2019 Intel Corporation -# -# SPDX-License-Identifier: MIT - -from itertools import zip_longest -import numpy as np - -from datumaro.components.extractor import AnnotationType, LabelCategories - - -class Comparator: - def __init__(self, - iou_threshold=0.5, conf_threshold=0.9): - self.iou_threshold = iou_threshold - self.conf_threshold = conf_threshold - - @staticmethod - def iou(box_a, box_b): - return box_a.iou(box_b) - - # pylint: disable=no-self-use - def compare_dataset_labels(self, extractor_a, extractor_b): - a_label_cat = extractor_a.categories().get(AnnotationType.label) - b_label_cat = extractor_b.categories().get(AnnotationType.label) - if not a_label_cat and not b_label_cat: - return None - if not a_label_cat: - a_label_cat = LabelCategories() - if not b_label_cat: - b_label_cat = LabelCategories() - - mismatches = [] - for a_label, b_label in zip_longest(a_label_cat.items, b_label_cat.items): - if a_label != b_label: - mismatches.append((a_label, b_label)) - return mismatches - # pylint: enable=no-self-use - - def compare_item_labels(self, item_a, item_b): - conf_threshold = self.conf_threshold - - a_labels = set([ann.label for ann in item_a.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - b_labels = set([ann.label for ann in item_b.annotations \ - if ann.type is AnnotationType.label and \ - conf_threshold < ann.attributes.get('score', 1)]) - - a_unmatched = a_labels - b_labels - b_unmatched = b_labels - a_labels - matches = a_labels & b_labels - - return matches, a_unmatched, b_unmatched - - def compare_item_bboxes(self, item_a, item_b): - iou_threshold = self.iou_threshold - conf_threshold = self.conf_threshold - - a_boxes = [ann for ann in item_a.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - b_boxes = [ann for ann in item_b.annotations \ - if ann.type is AnnotationType.bbox and \ - conf_threshold < ann.attributes.get('score', 1)] - a_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - b_boxes.sort(key=lambda ann: 1 - ann.attributes.get('score', 1)) - - # a_matches: indices of b_boxes matched to a bboxes - # b_matches: indices of a_boxes matched to b bboxes - a_matches = -np.ones(len(a_boxes), dtype=int) - b_matches = -np.ones(len(b_boxes), dtype=int) - - iou_matrix = np.array([ - [self.iou(a, b) for b in b_boxes] for a in a_boxes - ]) - - # matches: boxes we succeeded to match completely - # mispred: boxes we succeeded to match, having label mismatch - matches = [] - mispred = [] - - for a_idx, a_bbox in enumerate(a_boxes): - if len(b_boxes) == 0: - break - matched_b = a_matches[a_idx] - iou_max = max(iou_matrix[a_idx, matched_b], iou_threshold) - for b_idx, b_bbox in enumerate(b_boxes): - if 0 <= b_matches[b_idx]: # assign a_bbox with max conf - continue - iou = iou_matrix[a_idx, b_idx] - if iou < iou_max: - continue - iou_max = iou - matched_b = b_idx - - if matched_b < 0: - continue - a_matches[a_idx] = matched_b - b_matches[matched_b] = a_idx - - b_bbox = b_boxes[matched_b] - - if a_bbox.label == b_bbox.label: - matches.append( (a_bbox, b_bbox) ) - else: - mispred.append( (a_bbox, b_bbox) ) - - # *_umatched: boxes of (*) we failed to match - a_unmatched = [a_boxes[i] for i, m in enumerate(a_matches) if m < 0] - b_unmatched = [b_boxes[i] for i, m in enumerate(b_matches) if m < 0] - - return matches, mispred, a_unmatched, b_unmatched diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index d7991cd121e..573f8d4ff40 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -46,7 +46,7 @@ def wrap(item, **kwargs): @attrs class Categories: attributes = attrib(factory=set, validator=default_if_none(set), - kw_only=True) + kw_only=True, eq=False) @attrs class LabelCategories(Categories): diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 9e63d3a7e84..e27a296e357 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -1003,3 +1003,232 @@ def get_label(ann): } for c, (bin_min, bin_max) in zip(hist, zip(bins[:-1], bins[1:]))] return stats + +@attrs +class DistanceComparator: + iou_threshold = attrib(converter=float, default=0.5) + + @staticmethod + def match_datasets(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + + @staticmethod + def match_classes(a, b): + a_label_cat = a.categories().get(AnnotationType.label, LabelCategories()) + b_label_cat = b.categories().get(AnnotationType.label, LabelCategories()) + + a_labels = set(c.name for c in a_label_cat) + b_labels = set(c.name for c in b_label_cat) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def match_annotations(self, item_a, item_b): + return { t: self._match_ann_type(t, item_a, item_b) } + + def _match_ann_type(self, t, *args): + if t == AnnotationType.label: + return self.match_labels(*args) + elif t == AnnotationType.bbox: + return self.match_boxes(*args) + elif t == AnnotationType.polygon: + return self.match_polygons(*args) + elif t == AnnotationType.mask: + return self.match_masks(*args) + elif t == AnnotationType.points: + return self.match_points(*args) + elif t == AnnotationType.polyline: + return self.match_lines(*args) + else: + raise NotImplementedError("Unexpected annotation type %s" % t) + + @staticmethod + def _get_ann_type(t, item): + return get_ann_type(item.annotations, t) + + def match_labels(self, item_a, item_b): + a_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_a)) + b_labels = set(a.label for a in + self._get_ann_type(AnnotationType.label, item_b)) + + matches = a_labels & b_labels + a_unmatched = a_labels - b_labels + b_unmatched = b_labels - a_labels + return matches, a_unmatched, b_unmatched + + def _match_segments(self, t, item_a, item_b): + a_boxes = self._get_ann_type(t, item_a) + b_boxes = self._get_ann_type(t, item_b) + return match_segments(a_boxes, b_boxes, dist_thresh=self.iou_threshold) + + def match_polygons(self, item_a, item_b): + return self._match_segments(AnnotationType.polygon, item_a, item_b) + + def match_masks(self, item_a, item_b): + return self._match_segments(AnnotationType.mask, item_a, item_b) + + def match_boxes(self, item_a, item_b): + return self._match_segments(AnnotationType.bbox, item_a, item_b) + + def match_points(self, item_a, item_b): + a_points = self._get_ann_type(AnnotationType.points, item_a) + b_points = self._get_ann_type(AnnotationType.points, item_b) + + instance_map = {} + for s in sources: + s_instances = find_instances(s) + for inst in s_instances: + inst_bbox = max_bbox(inst) + for ann in inst: + instance_map[id(ann)] = [inst, inst_bbox] + matcher = PointsMatcher(instance_map=instance_map) + distance = lambda a, b: matcher.distance(a, b) + + return match_segments(a_points, b_points, + dist_thresh=self.iou_threshold, distance=distance) + + def match_lines(self, item_a, item_b): + a_lines = self._get_ann_type(AnnotationType.polyline, item_a) + b_lines = self._get_ann_type(AnnotationType.polyline, item_b) + + matcher = LineMatcher() + distance = lambda a, b: matcher.distance(a, b) + + return match_segments(a_lines, b_lines, + dist_thresh=self.iou_threshold, distance=distance) + +@attrs +class ExactComparator: + ignored_fields = attrib(kw_only=True, factory=set, converter=set) + ignored_attrs = attrib(kw_only=True, factory=set, converter=set) + ignored_item_attrs = attrib(kw_only=True, factory=set, converter=set) + + _test = attrib(init=False, type=TestCase) + + def __attrs_post_init__(self): + self._test = TestCase() + self._test.maxDiff = None + + + @staticmethod + def _match_datasets(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + + def _compare_categories(self, a, b): + test = self._test + + errors = [] + try: + test.assertEqual( + sorted(a, key=lambda t: t.value), + sorted(b, key=lambda t: t.value) + ) + except AssertionError as e: + errors.append({'type': 'categories', 'message': str(e)}) + + if AnnotationType.label in a: + try: + test.assertEqual( + a[AnnotationType.label].items, + b[AnnotationType.label].items, + ) + except AssertionError as e: + errors.append({'type': 'labels', 'message': str(e)}) + if AnnotationType.mask in a: + try: + test.assertEqual( + a[AnnotationType.mask].colormap, + b[AnnotationType.mask].colormap, + ) + except AssertionError as e: + errors.append({'type': 'colormap', 'message': str(e)}) + if AnnotationType.points in a: + try: + test.assertEqual( + a[AnnotationType.points].items, + b[AnnotationType.points].items, + ) + except AssertionError as e: + errors.append({'type': 'points', 'message': str(e)}) + return errors + + def _compare_annotations(self, a, b): + ignored_fields = self.ignored_fields + ignored_attrs = self.ignored_attrs + + a_fields = { k: None for k in vars(a) if k in ignored_fields} + b_fields = { k: None for k in vars(b) if k in ignored_fields} + if 'attributes' not in ignored_fields: + a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs) + b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs) + + result = a.wrap(**a_fields) == b.wrap(**b_fields) + + return result + + def compare_datasets(self, a, b): + test = self._test + + errors = [] + + errors.extend(self._compare_categories(a.categories(), b.categories())) + + matched = [] + unmatched = [] + + items, a_extra_items, b_extra_items = self._match_datasets(a, b) + + if a.categories().get(AnnotationType.label) != \ + b.categories().get(AnnotationType.label): + return matched, unmatched, a_extra_items, b_extra_items, errors + + for item_id in items: + item_a = a.get(*item_id) + item_b = b.get(*item_id) + + try: + test.assertEqual( + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs) + ) + except AssertionError as e: + errors.append({'type': 'item_attr', + 'item': item_id, 'message': str(e)}) + + b_annotations = item_b.annotations[:] + for ann_a in item_a.annotations: + ann_b_candidates = [x for x in item_b.annotations + if x.type == ann_a.type] + + ann_b = find(enumerate(self._compare_annotations(ann_a, x) + for x in ann_b_candidates), lambda x: x[1]) + if ann_b is None: + unmatched.append({ + 'item': item_id, 'source': 'a', 'ann': str(ann_a), + }) + continue + else: + ann_b = ann_b_candidates[ann_b[0]] + + b_annotations.remove(ann_b) # avoid repeats + matched.append({'item': item_id, 'a': str(ann_a), 'b': str(ann_b)}) + + for ann_b in b_annotations: + unmatched.append({'item': item_id, 'source': 'b', 'ann': str(ann_b)}) + + return matched, unmatched, a_extra_items, b_extra_items, errors \ No newline at end of file diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index 293bb5f62f3..dd3e0c21033 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -88,3 +88,14 @@ def str_to_bool(s): return False else: raise ValueError("Can't convert value '%s' to bool" % s) + +def ensure_cls(c): + def converter(arg): + if isinstance(arg, c): + return arg + else: + return c(**arg) + return converter + +def filter_dict(d, exclude_keys): + return { k: v for k, v in d.items() if k not in exclude_keys } \ No newline at end of file diff --git a/datumaro/datumaro/util/test_utils.py b/datumaro/datumaro/util/test_utils.py index f93a74ce1b3..62973ca5a0a 100644 --- a/datumaro/datumaro/util/test_utils.py +++ b/datumaro/datumaro/util/test_utils.py @@ -100,8 +100,7 @@ def compare_datasets(test, expected, actual, ignored_attrs=None): ann_b = find(ann_b_matches, lambda x: _compare_annotations(x, ann_a, ignored_attrs=ignored_attrs)) if ann_b is None: - test.assertEqual(ann_a, ann_b, - 'ann %s, candidates %s' % (ann_a, ann_b_matches)) + test.fail('ann %s, candidates %s' % (ann_a, ann_b_matches)) item_b.annotations.remove(ann_b) # avoid repeats def compare_datasets_strict(test, expected, actual): diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 9ad9c1de6fd..4ea145af58a 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,123 +1,97 @@ -from unittest import TestCase +import numpy as np + +from datumaro.components.extractor import DatasetItem, Label, Bbox, Caption, Mask, Points +from datumaro.components.project import Dataset +from datumaro.components.operations import DistanceComparator, ExactComparator -from datumaro.components.extractor import DatasetItem, Label, Bbox -from datumaro.components.comparator import Comparator +from unittest import TestCase -class DiffTest(TestCase): +class DistanceComparatorTest(TestCase): def test_no_bbox_diff_with_same_item(self): detections = 3 anns = [ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ] item = DatasetItem(id=0, annotations=anns) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item, item) + result = comp.match_boxes(item, item) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) for a_bbox, b_bbox in matches: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual(a_bbox.label, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_bbox_with_wrong_label(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) ]) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(mispred)) + self.assertEqual(len(item1.annotations), len(mispred)) self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) self.assertEqual(0, len(matches)) for a_bbox, b_bbox in mispred: self.assertLess(iou_thresh, a_bbox.iou(b_bbox)) self.assertEqual((a_bbox.label + 1) % class_count, b_bbox.label) - self.assertLess(conf_thresh, a_bbox.attributes['score']) - self.assertLess(conf_thresh, b_bbox.attributes['score']) def test_can_find_missing_boxes(self): detections = 3 class_count = 2 item1 = DatasetItem(id=1, annotations=[ - Bbox(i * 10, 10, 10, 10, label=i, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 0 + Bbox(i * 10, 10, 10, 10, label=i) + for i in range(detections) if i % 2 == 0 ]) item2 = DatasetItem(id=2, annotations=[ - Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count, - attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) if i % 2 == 1 + Bbox(i * 10, 10, 10, 10, label=(i + 1) % class_count) + for i in range(detections) if i % 2 == 1 ]) iou_thresh = 0.5 - conf_thresh = 0.5 - comp = Comparator( - iou_threshold=iou_thresh, conf_threshold=conf_thresh) + comp = DistanceComparator(iou_threshold=iou_thresh) - result = comp.compare_item_bboxes(item1, item2) + result = comp.match_boxes(item1, item2) matches, mispred, a_greater, b_greater = result self.assertEqual(0, len(mispred)) - self.assertEqual(len([it for it in item1.annotations \ - if conf_thresh < it.attributes['score']]), - len(a_greater)) - self.assertEqual(len([it for it in item2.annotations \ - if conf_thresh < it.attributes['score']]), - len(b_greater)) + self.assertEqual(len(item1.annotations), len(a_greater)) + self.assertEqual(len(item2.annotations), len(b_greater)) self.assertEqual(0, len(matches)) def test_no_label_diff_with_same_item(self): detections = 3 anns = [ - Label(i, attributes={'score': (1.0 + i) / detections}) \ - for i in range(detections) + Label(i) for i in range(detections) ] item = DatasetItem(id=1, annotations=anns) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item, item) + result = DistanceComparator().match_labels(item, item) matches, a_greater, b_greater = result self.assertEqual(0, len(a_greater)) self.assertEqual(0, len(b_greater)) - self.assertEqual(len([it for it in item.annotations \ - if conf_thresh < it.attributes['score']]), - len(matches)) + self.assertEqual(len(item.annotations), len(matches)) def test_can_find_wrong_label(self): item1 = DatasetItem(id=1, annotations=[ @@ -131,12 +105,73 @@ def test_can_find_wrong_label(self): Label(4), ]) - conf_thresh = 0.5 - comp = Comparator(conf_threshold=conf_thresh) - - result = comp.compare_item_labels(item1, item2) + result = DistanceComparator().match_labels(item1, item2) matches, a_greater, b_greater = result self.assertEqual(2, len(a_greater)) self.assertEqual(2, len(b_greater)) - self.assertEqual(1, len(matches)) \ No newline at end of file + self.assertEqual(1, len(matches)) + +class ExactComparatorTest(TestCase): + def test_class_comparison(self): + a = Dataset.from_iterable([], categories=['a', 'b', 'c']) + b = Dataset.from_iterable([], categories=['b', 'c']) + + comp = ExactComparator() + _, _, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(1, len(errors), errors) + + def test_item_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, subset='train'), + DatasetItem(id=2, subset='test', attributes={'x': 1}), + ], categories=['a', 'b', 'c']) + + b = Dataset.from_iterable([ + DatasetItem(id=2, subset='test'), + DatasetItem(id=3), + ], categories=['a', 'b', 'c']) + + comp = ExactComparator() + _, _, a_extra_items, b_extra_items, errors = comp.compare_datasets(a, b) + + self.assertEqual({('1', 'train')}, a_extra_items) + self.assertEqual({('3', '')}, b_extra_items) + self.assertEqual(1, len(errors), errors) + + def test_annotation_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('hello'), # unmatched + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=1, annotations=[ + Caption('world', group=5), + Label(2, attributes={ 'x': 1, 'y': '2', }), + Bbox(1, 2, 3, 4, label=4, z_order=1, attributes={ + 'score': 1.0, + }), + Bbox(5, 6, 7, 8, group=5), + Bbox(5, 6, 7, 8, group=5), # unmatched + Points([1, 2, 2, 0, 1, 1], label=0, z_order=4), + Mask(label=3, z_order=2, image=np.ones((2, 3))), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator() + matched, unmatched, _, _, errors = comp.compare_datasets(a, b) + + self.assertEqual(6, len(matched), matched) + self.assertEqual(2, len(unmatched), unmatched) + self.assertEqual(0, len(errors), errors) \ No newline at end of file From bf8a4a1306d31d7be36956cfecc619c90823aa1f Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Tue, 18 Aug 2020 17:40:05 +0300 Subject: [PATCH 2/8] Update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8df9cb1cbd5..e4ea48c9319 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Siammask tracker as DL serverless function () - [Datumaro] Added model info and source info commands () - [Datumaro] Dataset statistics () -- [Datumaro] Multi-dataset merge (https://github.com/opencv/cvat/pull/1695) +- [Datumaro] Multi-dataset merge () +- [Datumaro] CLI command for dataset equality comparison () ### Changed - Shape coordinates are rounded to 2 digits in dumped annotations () From d81b9de7b9301ad12ad4023b457375d276dc467b Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Tue, 18 Aug 2020 17:43:50 +0300 Subject: [PATCH 3/8] fix --- datumaro/datumaro/components/operations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index e27a296e357..d2fc371b805 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -11,11 +11,12 @@ import cv2 import numpy as np from attr import attrib, attrs +from unittest import TestCase from datumaro.components.cli_plugin import CliPlugin from datumaro.components.extractor import AnnotationType, Bbox, Label from datumaro.components.project import Dataset -from datumaro.util import find +from datumaro.util import find, filter_dict from datumaro.util.attrs_util import ensure_cls from datumaro.util.annotation_util import (segment_iou, bbox_iou, mean_bbox, OKS, find_instances, max_bbox, smooth_line) From 75172d90297955f9c3d084ede0e130bbbf640dea Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Tue, 18 Aug 2020 17:46:26 +0300 Subject: [PATCH 4/8] fix merge --- datumaro/datumaro/util/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/datumaro/datumaro/util/__init__.py b/datumaro/datumaro/util/__init__.py index dd3e0c21033..010057d54c6 100644 --- a/datumaro/datumaro/util/__init__.py +++ b/datumaro/datumaro/util/__init__.py @@ -89,13 +89,5 @@ def str_to_bool(s): else: raise ValueError("Can't convert value '%s' to bool" % s) -def ensure_cls(c): - def converter(arg): - if isinstance(arg, c): - return arg - else: - return c(**arg) - return converter - def filter_dict(d, exclude_keys): return { k: v for k, v in d.items() if k not in exclude_keys } \ No newline at end of file From 3478727a948eb4432fc286a67bfbf22318d2ef5f Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Wed, 26 Aug 2020 18:01:59 +0300 Subject: [PATCH 5/8] Add image matching, add test --- .../datumaro/cli/contexts/project/__init__.py | 15 +- datumaro/datumaro/components/extractor.py | 4 + datumaro/datumaro/components/operations.py | 212 +++++++++++++----- datumaro/tests/test_diff.py | 57 ++++- 4 files changed, 230 insertions(+), 58 deletions(-) diff --git a/datumaro/datumaro/cli/contexts/project/__init__.py b/datumaro/datumaro/cli/contexts/project/__init__.py index d28b1bfc41e..f969b418273 100644 --- a/datumaro/datumaro/cli/contexts/project/__init__.py +++ b/datumaro/datumaro/cli/contexts/project/__init__.py @@ -581,12 +581,14 @@ def build_ediff_parser(parser_ctor=argparse.ArgumentParser): parser.add_argument('other_project_dir', help="Directory of the second project to be compared") parser.add_argument('-iia', '--ignore-item-attr', action='append', - help="Ignore an item attribute (repeatable)") + help="Ignore item attribute (repeatable)") parser.add_argument('-ia', '--ignore-attr', action='append', - help="Ignore an annotation attribute (repeatable)") + help="Ignore annotation attribute (repeatable)") parser.add_argument('-if', '--ignore-field', action='append', default=['id', 'group'], - help="Ignore an annotation field (repeatable, default: %(default)s)") + help="Ignore annotation field (repeatable, default: %(default)s)") + parser.add_argument('--match-images', action='store_true', + help='Match dataset items by images instead of ids') parser.add_argument('--all', action='store_true', help="Include matches in the output") parser.add_argument('-p', '--project', dest='project_dir', default='.', @@ -600,9 +602,10 @@ def ediff_command(args): second_project = load_project(args.other_project_dir) comparator = ExactComparator( - ignored_fields=args.ignore_field or [], - ignored_attrs=args.ignore_attr or [], - ignored_item_attrs=args.ignore_item_attr or []) + match_images=args.match_images, + ignored_fields=args.ignore_field, + ignored_attrs=args.ignore_attr, + ignored_item_attrs=args.ignore_item_attr) matches, mismatches, a_extra, b_extra, errors = \ comparator.compare_datasets( first_project.make_dataset(), second_project.make_dataset()) diff --git a/datumaro/datumaro/components/extractor.py b/datumaro/datumaro/components/extractor.py index 573f8d4ff40..0473a250b46 100644 --- a/datumaro/datumaro/components/extractor.py +++ b/datumaro/datumaro/components/extractor.py @@ -137,6 +137,8 @@ def inverse_colormap(self): def __eq__(self, other): if not super().__eq__(other): return False + if not isinstance(other, __class__): + return False for label_id, my_color in self.colormap.items(): other_color = other.colormap.get(label_id) if not np.array_equal(my_color, other_color): @@ -179,6 +181,8 @@ def paint(self, colormap): def __eq__(self, other): if not super().__eq__(other): return False + if not isinstance(other, __class__): + return False return \ (self.label == other.label) and \ (self.z_order == other.z_order) and \ diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index d2fc371b805..2b86f73331b 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -5,6 +5,7 @@ from collections import OrderedDict from copy import deepcopy +import hashlib import logging as log import attr @@ -17,7 +18,7 @@ from datumaro.components.extractor import AnnotationType, Bbox, Label from datumaro.components.project import Dataset from datumaro.util import find, filter_dict -from datumaro.util.attrs_util import ensure_cls +from datumaro.util.attrs_util import ensure_cls, default_if_none from datumaro.util.annotation_util import (segment_iou, bbox_iou, mean_bbox, OKS, find_instances, max_bbox, smooth_line) @@ -1107,33 +1108,77 @@ def match_lines(self, item_a, item_b): return match_segments(a_lines, b_lines, dist_thresh=self.iou_threshold, distance=distance) +def match_items_by_id(a, b): + a_items = set((item.id, item.subset) for item in a) + b_items = set((item.id, item.subset) for item in b) + + matches = a_items & b_items + matches = [([m], [m]) for m in matches] + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + return matches, a_unmatched, b_unmatched + +def match_items_by_image_hash(a, b): + def _hash(item): + if not item.image.has_data: + log.warning("Image (%s, %s) has no image " + "data, counted as unmatched", item.id, item.subset) + return None + return hashlib.md5(item.image.data.tobytes()).hexdigest() + + def _build_hashmap(source): + d = {} + for item in source: + h = _hash(item) + if h is None: + h = str(id(item)) # anything unique + d.setdefault(h, []).append((item.id, item.subset)) + return d + + a_hash = _build_hashmap(a) + b_hash = _build_hashmap(b) + + a_items = set(a_hash) + b_items = set(b_hash) + + matches = a_items & b_items + a_unmatched = a_items - b_items + b_unmatched = b_items - a_items + + matches = [(a_hash[h], b_hash[h]) for h in matches] + a_unmatched = set(i for h in a_unmatched for i in a_hash[h]) + b_unmatched = set(i for h in b_unmatched for i in b_hash[h]) + + return matches, a_unmatched, b_unmatched + @attrs class ExactComparator: - ignored_fields = attrib(kw_only=True, factory=set, converter=set) - ignored_attrs = attrib(kw_only=True, factory=set, converter=set) - ignored_item_attrs = attrib(kw_only=True, factory=set, converter=set) + match_images = attrib(kw_only=True, type=bool, default=False) + ignored_fields = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) + ignored_attrs = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) + ignored_item_attrs = attrib(kw_only=True, + factory=set, validator=default_if_none(set)) _test = attrib(init=False, type=TestCase) + errors = attrib(init=False, type=list) def __attrs_post_init__(self): self._test = TestCase() self._test.maxDiff = None - @staticmethod - def _match_datasets(a, b): - a_items = set((item.id, item.subset) for item in a) - b_items = set((item.id, item.subset) for item in b) - - matches = a_items & b_items - a_unmatched = a_items - b_items - b_unmatched = b_items - a_items - return matches, a_unmatched, b_unmatched + def _match_items(self, a, b): + if self.match_images: + return match_items_by_image_hash(a, b) + else: + return match_items_by_id(a, b) def _compare_categories(self, a, b): test = self._test + errors = self.errors - errors = [] try: test.assertEqual( sorted(a, key=lambda t: t.value), @@ -1166,14 +1211,13 @@ def _compare_categories(self, a, b): ) except AssertionError as e: errors.append({'type': 'points', 'message': str(e)}) - return errors def _compare_annotations(self, a, b): ignored_fields = self.ignored_fields ignored_attrs = self.ignored_attrs - a_fields = { k: None for k in vars(a) if k in ignored_fields} - b_fields = { k: None for k in vars(b) if k in ignored_fields} + a_fields = { k: None for k in vars(a) if k in ignored_fields } + b_fields = { k: None for k in vars(b) if k in ignored_fields } if 'attributes' not in ignored_fields: a_fields['attributes'] = filter_dict(a.attributes, ignored_attrs) b_fields['attributes'] = filter_dict(b.attributes, ignored_attrs) @@ -1182,54 +1226,120 @@ def _compare_annotations(self, a, b): return result - def compare_datasets(self, a, b): + def _compare_items(self, item_a, item_b): test = self._test + a_id = (item_a.id, item_a.subset) + b_id = (item_b.id, item_b.subset) + + matched = [] + unmatched = [] errors = [] - errors.extend(self._compare_categories(a.categories(), b.categories())) + try: + test.assertEqual( + filter_dict(item_a.attributes, self.ignored_item_attrs), + filter_dict(item_b.attributes, self.ignored_item_attrs) + ) + except AssertionError as e: + errors.append({'type': 'item_attr', + 'a_item': a_id, 'b_item': b_id, 'message': str(e)}) + + b_annotations = item_b.annotations[:] + for ann_a in item_a.annotations: + ann_b_candidates = [x for x in item_b.annotations + if x.type == ann_a.type] + + ann_b = find(enumerate(self._compare_annotations(ann_a, x) + for x in ann_b_candidates), lambda x: x[1]) + if ann_b is None: + unmatched.append({ + 'item': a_id, 'source': 'a', 'ann': str(ann_a), + }) + continue + else: + ann_b = ann_b_candidates[ann_b[0]] + + b_annotations.remove(ann_b) # avoid repeats + matched.append({'a_item': a_id, 'b_item': b_id, + 'a': str(ann_a), 'b': str(ann_b)}) + + for ann_b in b_annotations: + unmatched.append({'item': b_id, 'source': 'b', 'ann': str(ann_b)}) + + return matched, unmatched, errors + + def compare_datasets(self, a, b): + self.errors = [] + errors = self.errors + + self._compare_categories(a.categories(), b.categories()) matched = [] unmatched = [] - items, a_extra_items, b_extra_items = self._match_datasets(a, b) + matches, a_unmatched, b_unmatched = self._match_items(a, b) if a.categories().get(AnnotationType.label) != \ b.categories().get(AnnotationType.label): - return matched, unmatched, a_extra_items, b_extra_items, errors + return matched, unmatched, a_unmatched, b_unmatched, errors - for item_id in items: - item_a = a.get(*item_id) - item_b = b.get(*item_id) + _dist = lambda s: len(s[1]) + len(s[2]) + for a_ids, b_ids in matches: + # build distance matrix + match_status = {} # (a_id, b_id): [matched, unmatched, errors] + a_matches = { a_id: None for a_id in a_ids } + b_matches = { b_id: None for b_id in b_ids } - try: - test.assertEqual( - filter_dict(item_a.attributes, self.ignored_item_attrs), - filter_dict(item_b.attributes, self.ignored_item_attrs) - ) - except AssertionError as e: - errors.append({'type': 'item_attr', - 'item': item_id, 'message': str(e)}) - - b_annotations = item_b.annotations[:] - for ann_a in item_a.annotations: - ann_b_candidates = [x for x in item_b.annotations - if x.type == ann_a.type] - - ann_b = find(enumerate(self._compare_annotations(ann_a, x) - for x in ann_b_candidates), lambda x: x[1]) - if ann_b is None: - unmatched.append({ - 'item': item_id, 'source': 'a', 'ann': str(ann_a), - }) + for a_id in a_ids: + item_a = a.get(*a_id) + candidates = {} + + for b_id in b_ids: + item_b = b.get(*b_id) + + i_m, i_um, i_err = self._compare_items(item_a, item_b) + candidates[b_id] = [i_m, i_um, i_err] + + if len(i_um) == 0: + a_matches[a_id] = b_id + b_matches[b_id] = a_id + matched.extend(i_m) + errors.extend(i_err) + break + + match_status[a_id] = candidates + + # assign + for a_id in a_ids: + if len(b_ids) == 0: + break + + # find the closest, ignore already assigned + matched_b = a_matches[a_id] + if matched_b is not None: continue - else: - ann_b = ann_b_candidates[ann_b[0]] + min_dist = -1 + for b_id in b_ids: + if b_matches[b_id] is not None: + continue + d = _dist(match_status[a_id][b_id]) + if d < min_dist and 0 <= min_dist: + continue + min_dist = d + matched_b = b_id + + if matched_b is None: + continue + a_matches[a_id] = matched_b + b_matches[matched_b] = a_id - b_annotations.remove(ann_b) # avoid repeats - matched.append({'item': item_id, 'a': str(ann_a), 'b': str(ann_b)}) + m = match_status[a_id][matched_b] + matched.extend(m[0]) + unmatched.extend(m[1]) + errors.extend(m[2]) - for ann_b in b_annotations: - unmatched.append({'item': item_id, 'source': 'b', 'ann': str(ann_b)}) + a_unmatched |= set(a_id for a_id, m in a_matches.items() if not m) + b_unmatched |= set(b_id for b_id, m in b_matches.items() if not m) - return matched, unmatched, a_extra_items, b_extra_items, errors \ No newline at end of file + return matched, unmatched, a_unmatched, b_unmatched, errors \ No newline at end of file diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 4ea145af58a..2bf41253bdc 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -174,4 +174,59 @@ def test_annotation_comparison(self): self.assertEqual(6, len(matched), matched) self.assertEqual(2, len(unmatched), unmatched) - self.assertEqual(0, len(errors), errors) \ No newline at end of file + self.assertEqual(0, len(errors), errors) + + def test_image_comparison(self): + a = Dataset.from_iterable([ + DatasetItem(id=11, image=np.ones((5, 4, 3)), annotations=[ + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=12, image=np.ones((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=13, image=np.ones((5, 4, 3)), annotations=[ + Bbox(9, 10, 11, 12), # mismatch + ]), + + DatasetItem(id=14, image=np.zeros((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ], attributes={ 'a': 1 }), + + DatasetItem(id=15, image=np.zeros((5, 5, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + ], categories=['a', 'b', 'c', 'd']) + + b = Dataset.from_iterable([ + DatasetItem(id=21, image=np.ones((5, 4, 3)), annotations=[ + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=22, image=np.ones((5, 4, 3)), annotations=[ + Bbox(1, 2, 3, 4), + Bbox(5, 6, 7, 8), + ]), + DatasetItem(id=23, image=np.ones((5, 4, 3)), annotations=[ + Bbox(10, 10, 11, 12), # mismatch + ]), + + DatasetItem(id=24, image=np.zeros((5, 4, 3)), annotations=[ + Bbox(6, 6, 7, 8), # 1 ann missing, mismatch + ], attributes={ 'a': 2 }), + + DatasetItem(id=25, image=np.zeros((4, 4, 3)), annotations=[ + Bbox(6, 6, 7, 8), + ]), + ], categories=['a', 'b', 'c', 'd']) + + comp = ExactComparator(match_images=True) + matched_ann, unmatched_ann, a_unmatched, b_unmatched, errors = \ + comp.compare_datasets(a, b) + + self.assertEqual(3, len(matched_ann), matched_ann) + self.assertEqual(5, len(unmatched_ann), unmatched_ann) + self.assertEqual(1, len(a_unmatched), a_unmatched) + self.assertEqual(1, len(b_unmatched), b_unmatched) + self.assertEqual(1, len(errors), errors) \ No newline at end of file From 9589e8a6cf3d35b10c1de75121284ec8496b5e62 Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Mon, 31 Aug 2020 14:38:08 +0300 Subject: [PATCH 6/8] Add point matching test --- datumaro/datumaro/components/operations.py | 10 ++++---- datumaro/tests/test_diff.py | 27 ++++++++++++++++++---- 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 2b86f73331b..6e3d2dd3d98 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -587,7 +587,7 @@ class MaskMatcher(_ShapeMatcher): @attrs(kw_only=True) class PointsMatcher(_ShapeMatcher): - sigma = attrib(converter=list, default=None) + sigma = attrib(type=list, default=None) instance_map = attrib(converter=dict) def distance(self, a, b): @@ -1086,27 +1086,25 @@ def match_points(self, item_a, item_b): b_points = self._get_ann_type(AnnotationType.points, item_b) instance_map = {} - for s in sources: + for s in [item_a.annotations, item_b.annotations]: s_instances = find_instances(s) for inst in s_instances: inst_bbox = max_bbox(inst) for ann in inst: instance_map[id(ann)] = [inst, inst_bbox] matcher = PointsMatcher(instance_map=instance_map) - distance = lambda a, b: matcher.distance(a, b) return match_segments(a_points, b_points, - dist_thresh=self.iou_threshold, distance=distance) + dist_thresh=self.iou_threshold, distance=matcher.distance) def match_lines(self, item_a, item_b): a_lines = self._get_ann_type(AnnotationType.polyline, item_a) b_lines = self._get_ann_type(AnnotationType.polyline, item_b) matcher = LineMatcher() - distance = lambda a, b: matcher.distance(a, b) return match_segments(a_lines, b_lines, - dist_thresh=self.iou_threshold, distance=distance) + dist_thresh=self.iou_threshold, distance=matcher.distance) def match_items_by_id(a, b): a_items = set((item.id, item.subset) for item in a) diff --git a/datumaro/tests/test_diff.py b/datumaro/tests/test_diff.py index 2bf41253bdc..33dd79da0ff 100644 --- a/datumaro/tests/test_diff.py +++ b/datumaro/tests/test_diff.py @@ -1,6 +1,7 @@ import numpy as np -from datumaro.components.extractor import DatasetItem, Label, Bbox, Caption, Mask, Points +from datumaro.components.extractor import (DatasetItem, Label, Bbox, + Caption, Mask, Points) from datumaro.components.project import Dataset from datumaro.components.operations import DistanceComparator, ExactComparator @@ -81,9 +82,7 @@ def test_can_find_missing_boxes(self): def test_no_label_diff_with_same_item(self): detections = 3 - anns = [ - Label(i) for i in range(detections) - ] + anns = [ Label(i) for i in range(detections) ] item = DatasetItem(id=1, annotations=anns) result = DistanceComparator().match_labels(item, item) @@ -112,6 +111,26 @@ def test_can_find_wrong_label(self): self.assertEqual(2, len(b_greater)) self.assertEqual(1, len(matches)) + def test_can_match_points(self): + item1 = DatasetItem(id=1, annotations=[ + Points([1, 2, 2, 0, 1, 1], label=0), + + Points([3, 5, 5, 7, 5, 3], label=0), + ]) + item2 = DatasetItem(id=2, annotations=[ + Points([1.5, 2, 2, 0.5, 1, 1.5], label=0), + + Points([5, 7, 7, 7, 7, 5], label=0), + ]) + + result = DistanceComparator().match_points(item1, item2) + + matches, mismatches, a_greater, b_greater = result + self.assertEqual(1, len(a_greater)) + self.assertEqual(1, len(b_greater)) + self.assertEqual(1, len(matches)) + self.assertEqual(0, len(mismatches)) + class ExactComparatorTest(TestCase): def test_class_comparison(self): a = Dataset.from_iterable([], categories=['a', 'b', 'c']) From 2ad4682418963c522f4fb1c401be731a625e8e3e Mon Sep 17 00:00:00 2001 From: Zhiltsov Max Date: Mon, 31 Aug 2020 14:39:53 +0300 Subject: [PATCH 7/8] linter --- datumaro/datumaro/components/operations.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datumaro/datumaro/components/operations.py b/datumaro/datumaro/components/operations.py index 6e3d2dd3d98..2e3a68136db 100644 --- a/datumaro/datumaro/components/operations.py +++ b/datumaro/datumaro/components/operations.py @@ -1037,6 +1037,7 @@ def match_annotations(self, item_a, item_b): return { t: self._match_ann_type(t, item_a, item_b) } def _match_ann_type(self, t, *args): + # pylint: disable=no-value-for-parameter if t == AnnotationType.label: return self.match_labels(*args) elif t == AnnotationType.bbox: @@ -1049,6 +1050,7 @@ def _match_ann_type(self, t, *args): return self.match_points(*args) elif t == AnnotationType.polyline: return self.match_lines(*args) + # pylint: enable=no-value-for-parameter else: raise NotImplementedError("Unexpected annotation type %s" % t) From 375766713b0aa6c8980e1b197a172246faacc9c2 Mon Sep 17 00:00:00 2001 From: Nikita Manovich Date: Wed, 2 Sep 2020 22:21:23 +0300 Subject: [PATCH 8/8] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2389f5620f7..f20e6bee276 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added password reset functionality () - Ability to work with data on the fly (https://github.com/opencv/cvat/pull/2007) - Annotation in process outline color wheel () +- [Datumaro] CLI command for dataset equality comparison () ### Changed - UI models (like DEXTR) were redesigned to be more interactive () @@ -35,7 +36,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Ability to configure email verification for new users () - Link to django admin page from UI () - Notification message when users use wrong browser () -- [Datumaro] CLI command for dataset equality comparison () ### Changed - Shape coordinates are rounded to 2 digits in dumped annotations ()