Skip to content

Commit

Permalink
ignoring point order in skeleton annotations when comparing annotatio…
Browse files Browse the repository at this point in the history
…ns (#57)

* when comparing datasets, ignoring point order in skeleton annotations and existence of absent points

* do not create empty files on export

* fix tests

* fix tests again

* if label points are missing, use -1 as a label for sorting purposes

* fixes

* fix

* fixing code duplication
  • Loading branch information
Eldies authored Aug 28, 2024
1 parent 125840f commit 393cb66
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 24 deletions.
13 changes: 12 additions & 1 deletion datumaro/components/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Label,
LabelCategories,
MaskCategories,
Points,
PointsCategories,
)
from datumaro.components.cli_plugin import CliPlugin
Expand Down Expand Up @@ -1856,7 +1857,7 @@ def _compare_categories(self, a, b):
except AssertionError as e:
errors.append({"type": "points", "message": str(e)})

def _compare_annotations(self, a, b):
def _compare_annotations(self, a: Annotation, b: Annotation):
ignored_fields = self.ignored_fields
ignored_attrs = self.ignored_attrs

Expand All @@ -1866,6 +1867,16 @@ def _compare_annotations(self, a, b):
a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs)
b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs)

if a.type == b.type == AnnotationType.skeleton and "elements" not in ignored_fields:
a_fields["elements"] = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, a.elements),
key=lambda p: p.label if p.label is not None else -1,
)
b_fields["elements"] = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, b.elements),
key=lambda p: p.label if p.label is not None else -1,
)

result = a.wrap(**a_fields) == b.wrap(**b_fields)

return result
Expand Down
13 changes: 8 additions & 5 deletions datumaro/plugins/yolo_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def _export_media(self, item: DatasetItem, subset_img_dir: str) -> str:
except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))

def _save_annotation_file(self, annotation_path, yolo_annotation):
    """Write the serialized annotation text to ``annotation_path``.

    The target is opened in write mode as UTF-8 text, so any existing file
    at that path is overwritten with ``yolo_annotation``.
    """
    with open(annotation_path, "w", encoding="utf-8") as label_file:
        label_file.write(yolo_annotation)

def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
try:
height, width = item.media.size
Expand All @@ -208,8 +212,7 @@ def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
annotation_path = osp.join(subset_dir, f"{item.id}{YoloPath.LABELS_EXT}")
os.makedirs(osp.dirname(annotation_path), exist_ok=True)

with open(annotation_path, "w", encoding="utf-8") as f:
f.write(yolo_annotation)
self._save_annotation_file(annotation_path, yolo_annotation)

except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))
Expand Down Expand Up @@ -288,9 +291,9 @@ def __init__(
super().__init__(extractor, save_dir, add_path_prefix=add_path_prefix, **kwargs)
self._config_filename = config_file or YOLOv8Path.DEFAULT_CONFIG_FILE

def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
if len(item.annotations) > 0:
super()._export_item_annotation(item, subset_dir)
def _save_annotation_file(self, annotation_path, yolo_annotation):
    """Write the annotation file only when there is annotation text to save.

    A falsy ``yolo_annotation`` (for example an empty string) is skipped, so
    no file is created for it; anything else is handed to the parent
    implementation unchanged.
    """
    if not yolo_annotation:
        # Nothing to write: do not create an empty file.
        return
    super()._save_annotation_file(annotation_path, yolo_annotation)

@classmethod
def build_cmdline_parser(cls, **kwargs):
Expand Down
33 changes: 21 additions & 12 deletions datumaro/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import unittest
import unittest.mock
import warnings
from copy import copy
from enum import Enum, auto
from glob import glob
from typing import Any, Callable, Collection, Optional, Union

from typing_extensions import Literal

from datumaro.components.annotation import AnnotationType
from datumaro.components.annotation import Annotation, AnnotationType, Points
from datumaro.components.dataset import Dataset, IDataset
from datumaro.components.media import Image, MultiframeImage, PointCloud
from datumaro.util import filter_dict, find
Expand Down Expand Up @@ -115,26 +116,34 @@ def compare_categories(test, expected, actual):
IGNORE_ALL = "*"


def compare_annotations(expected, actual, ignored_attrs=None):
if not ignored_attrs:
def compare_annotations(expected: Annotation, actual: Annotation, ignored_attrs=None):
is_skeleton = expected.type == actual.type == AnnotationType.skeleton
if not ignored_attrs and not is_skeleton:
return expected == actual

a_attr = expected.attributes
b_attr = actual.attributes
ignored_attrs = ignored_attrs or []

expected = copy(expected)
actual = copy(actual)

if ignored_attrs != IGNORE_ALL:
expected.attributes = filter_dict(a_attr, exclude_keys=ignored_attrs)
actual.attributes = filter_dict(b_attr, exclude_keys=ignored_attrs)
expected.attributes = filter_dict(expected.attributes, exclude_keys=ignored_attrs)
actual.attributes = filter_dict(actual.attributes, exclude_keys=ignored_attrs)
else:
expected.attributes = {}
actual.attributes = {}

r = expected == actual

expected.attributes = a_attr
actual.attributes = b_attr
if is_skeleton:
expected.elements = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, expected.elements),
key=lambda p: p.label if p.label is not None else -1,
)
actual.elements = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, actual.elements),
key=lambda p: p.label if p.label is not None else -1,
)

return r
return expected == actual


def compare_datasets(
Expand Down
152 changes: 151 additions & 1 deletion tests/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

import numpy as np

from datumaro.components.annotation import Bbox, Caption, Label, Mask, Points
from datumaro.components.annotation import (
AnnotationType,
Bbox,
Caption,
Label,
LabelCategories,
Mask,
Points,
PointsCategories,
Skeleton,
)
from datumaro.components.extractor import DEFAULT_SUBSET_NAME, DatasetItem
from datumaro.components.media import Image
from datumaro.components.operations import DistanceComparator, ExactComparator
Expand Down Expand Up @@ -268,6 +278,146 @@ def test_annotation_comparison(self):
self.assertEqual(2, len(unmatched), unmatched)
self.assertEqual(0, len(errors), errors)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_skeleton_annotation_comparison(self):
    """Check which skeleton annotations ExactComparator reports as unmatched.

    Both datasets hold three items with one skeleton each. Items 1 and 2
    differ only in element order or in points marked absent and must not be
    reported; item 3 differs in the coordinates of a hidden point and must be
    reported as unmatched from both sides.
    """
    visible = Points.Visibility.visible
    hidden = Points.Visibility.hidden
    absent = Points.Visibility.absent

    categories = {
        AnnotationType.label: LabelCategories.from_iterable(
            [
                "skeleton",
                ("point1", "skeleton"),
                ("point2", "skeleton"),
                ("point3", "skeleton"),
            ]
        ),
        AnnotationType.points: PointsCategories.from_iterable(
            [
                (0, ["point1", "point2", "point3"], set()),
            ]
        ),
    }

    def make_dataset(skeleton_elements):
        # One item per element list, ids counted from 1, every skeleton labelled 0.
        return Dataset.from_iterable(
            [
                DatasetItem(
                    id=item_id,
                    annotations=[Skeleton(elements, label=0)],
                )
                for item_id, elements in enumerate(skeleton_elements, start=1)
            ],
            categories=categories,
        )

    def unmatched_entry(source, elements):
        # Shape of one record in the comparator's "unmatched" list for item 3.
        return {
            "item": ("3", "default"),
            "source": source,
            "ann": repr(Skeleton(elements, label=0)),
        }

    dataset_a = make_dataset(
        [
            [
                Points([0, 1], [visible]),
                Points([1, 2], [hidden], label=2),
                Points([2, 3], [absent], label=3),
            ],
            [
                Points([4, 5], [visible], label=1),
                Points([5, 6], [hidden], label=2),
                Points([6, 7], [absent], label=3),
            ],
            [
                Points([7, 8], [visible], label=1),
                Points([8, 9], [hidden], label=2),
                Points([9, 10], [absent], label=3),
            ],
        ]
    )

    dataset_b = make_dataset(
        [
            # matched, even though absent point is removed and order differs
            [
                Points([1, 2], [hidden], label=2),
                Points([0, 1], [visible]),
            ],
            # matched, even though absent point has different coordinates
            [
                Points([4, 5], [visible], label=1),
                Points([5, 6], [hidden], label=2),
                Points([6, 8], [absent], label=3),
            ],
            # not matched, not-absent point has changed coordinates
            [
                Points([7, 8], [visible], label=1),
                Points([8, 10], [hidden], label=2),
                Points([9, 10], [absent], label=3),
            ],
        ]
    )

    comparator = ExactComparator()
    _, unmatched, _, _, _ = comparator.compare_datasets(dataset_a, dataset_b)

    expected_unmatched = [
        unmatched_entry(
            "a",
            [
                Points([7, 8], [2], label=1),
                Points([8, 9], [1], label=2),
                Points([9, 10], [0], label=3),
            ],
        ),
        unmatched_entry(
            "b",
            [
                Points([7, 8], [2], label=1),
                Points([8, 10], [1], label=2),
                Points([9, 10], [0], label=3),
            ],
        ),
    ]
    assert unmatched == expected_unmatched

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_image_comparison(self):
a = Dataset.from_iterable(
Expand Down
34 changes: 29 additions & 5 deletions tests/unit/data_formats/test_yolo_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AnnotationType,
Bbox,
LabelCategories,
Mask,
Points,
PointsCategories,
Polygon,
Expand Down Expand Up @@ -466,14 +467,37 @@ def test_can_save_and_load_without_path_prefix(self, test_dir):

self.compare_datasets(source_dataset, parsed_dataset)

@mark_requirement(Requirements.DATUM_609)
def test_can_save_and_load_without_annotations(self, test_dir):
source_dataset = self._generate_random_dataset([{"annotations": 0}])
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_save_without_creating_annotation_file_and_load(self, test_dir):
categories = self._generate_random_dataset([]).categories()
source_dataset = Dataset.from_iterable(
[
DatasetItem(
id=1,
subset="train",
media=Image(data=np.ones((8, 8, 3))),
annotations=[
# Mask annotation is not supported by yolo8 formats, so should be omitted
Mask(np.array([[0, 1, 1, 1, 0]]), label=0),
],
)
],
categories=categories,
)
expected_dataset = Dataset.from_iterable(
[
DatasetItem(
id=1,
subset="train",
media=Image(data=np.ones((8, 8, 3))),
)
],
categories=categories,
)
self.CONVERTER.convert(source_dataset, test_dir, save_media=True)

assert os.listdir(osp.join(test_dir, "labels", "train")) == []
parsed_dataset = Dataset.import_from(test_dir, self.IMPORTER.NAME)
self.compare_datasets(source_dataset, parsed_dataset)
self.compare_datasets(expected_dataset, parsed_dataset)

def _check_inplace_save_writes_only_updated_data(self, test_dir, expected):
assert set(os.listdir(osp.join(test_dir, "images", "train"))) == {
Expand Down

0 comments on commit 393cb66

Please sign in to comment.