Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ignoring point order in skeleton annotations when comparing annotations #57

Merged
merged 8 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion datumaro/components/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
Label,
LabelCategories,
MaskCategories,
Points,
PointsCategories,
)
from datumaro.components.cli_plugin import CliPlugin
Expand Down Expand Up @@ -1856,7 +1857,7 @@ def _compare_categories(self, a, b):
except AssertionError as e:
errors.append({"type": "points", "message": str(e)})

def _compare_annotations(self, a, b):
def _compare_annotations(self, a: Annotation, b: Annotation):
ignored_fields = self.ignored_fields
ignored_attrs = self.ignored_attrs

Expand All @@ -1866,6 +1867,16 @@ def _compare_annotations(self, a, b):
a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs)
b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs)

if a.type == b.type == AnnotationType.skeleton and "elements" not in ignored_fields:
a_fields["elements"] = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, a.elements),
key=lambda p: p.label if p.label is not None else -1,
)
b_fields["elements"] = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, b.elements),
key=lambda p: p.label if p.label is not None else -1,
)

result = a.wrap(**a_fields) == b.wrap(**b_fields)

return result
Expand Down
13 changes: 8 additions & 5 deletions datumaro/plugins/yolo_format/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ def _export_media(self, item: DatasetItem, subset_img_dir: str) -> str:
except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))

def _save_annotation_file(self, annotation_path, yolo_annotation):
with open(annotation_path, "w", encoding="utf-8") as f:
f.write(yolo_annotation)

def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
try:
height, width = item.media.size
Expand All @@ -208,8 +212,7 @@ def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
annotation_path = osp.join(subset_dir, f"{item.id}{YoloPath.LABELS_EXT}")
os.makedirs(osp.dirname(annotation_path), exist_ok=True)

with open(annotation_path, "w", encoding="utf-8") as f:
f.write(yolo_annotation)
self._save_annotation_file(annotation_path, yolo_annotation)

except Exception as e:
self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset))
Expand Down Expand Up @@ -288,9 +291,9 @@ def __init__(
super().__init__(extractor, save_dir, add_path_prefix=add_path_prefix, **kwargs)
self._config_filename = config_file or YOLOv8Path.DEFAULT_CONFIG_FILE

def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None:
if len(item.annotations) > 0:
super()._export_item_annotation(item, subset_dir)
def _save_annotation_file(self, annotation_path, yolo_annotation):
if yolo_annotation:
super()._save_annotation_file(annotation_path, yolo_annotation)

@classmethod
def build_cmdline_parser(cls, **kwargs):
Expand Down
33 changes: 21 additions & 12 deletions datumaro/util/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
import unittest
import unittest.mock
import warnings
from copy import copy
from enum import Enum, auto
from glob import glob
from typing import Any, Callable, Collection, Optional, Union

from typing_extensions import Literal

from datumaro.components.annotation import AnnotationType
from datumaro.components.annotation import Annotation, AnnotationType, Points
from datumaro.components.dataset import Dataset, IDataset
from datumaro.components.media import Image, MultiframeImage, PointCloud
from datumaro.util import filter_dict, find
Expand Down Expand Up @@ -115,26 +116,34 @@ def compare_categories(test, expected, actual):
IGNORE_ALL = "*"


def compare_annotations(expected, actual, ignored_attrs=None):
if not ignored_attrs:
def compare_annotations(expected: Annotation, actual: Annotation, ignored_attrs=None):
is_skeleton = expected.type == actual.type == AnnotationType.skeleton
if not ignored_attrs and not is_skeleton:
return expected == actual

a_attr = expected.attributes
b_attr = actual.attributes
ignored_attrs = ignored_attrs or []

expected = copy(expected)
actual = copy(actual)

if ignored_attrs != IGNORE_ALL:
expected.attributes = filter_dict(a_attr, exclude_keys=ignored_attrs)
actual.attributes = filter_dict(b_attr, exclude_keys=ignored_attrs)
expected.attributes = filter_dict(expected.attributes, exclude_keys=ignored_attrs)
actual.attributes = filter_dict(actual.attributes, exclude_keys=ignored_attrs)
else:
expected.attributes = {}
actual.attributes = {}

r = expected == actual

expected.attributes = a_attr
actual.attributes = b_attr
if is_skeleton:
expected.elements = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, expected.elements),
key=lambda p: p.label if p.label is not None else -1,
)
actual.elements = sorted(
filter(lambda p: p.visibility[0] != Points.Visibility.absent, actual.elements),
key=lambda p: p.label if p.label is not None else -1,
)

return r
return expected == actual


def compare_datasets(
Expand Down
152 changes: 151 additions & 1 deletion tests/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@

import numpy as np

from datumaro.components.annotation import Bbox, Caption, Label, Mask, Points
from datumaro.components.annotation import (
AnnotationType,
Bbox,
Caption,
Label,
LabelCategories,
Mask,
Points,
PointsCategories,
Skeleton,
)
from datumaro.components.extractor import DEFAULT_SUBSET_NAME, DatasetItem
from datumaro.components.media import Image
from datumaro.components.operations import DistanceComparator, ExactComparator
Expand Down Expand Up @@ -268,6 +278,146 @@ def test_annotation_comparison(self):
self.assertEqual(2, len(unmatched), unmatched)
self.assertEqual(0, len(errors), errors)

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_skeleton_annotation_comparison(self):
categories = {
AnnotationType.label: LabelCategories.from_iterable(
[
"skeleton",
("point1", "skeleton"),
("point2", "skeleton"),
("point3", "skeleton"),
]
),
AnnotationType.points: PointsCategories.from_iterable(
[
(0, ["point1", "point2", "point3"], set()),
]
),
}
a = Dataset.from_iterable(
[
DatasetItem(
id=1,
annotations=[
Skeleton(
[
Points([0, 1], [Points.Visibility.visible]),
Points([1, 2], [Points.Visibility.hidden], label=2),
Points([2, 3], [Points.Visibility.absent], label=3),
],
label=0,
)
],
),
DatasetItem(
id=2,
annotations=[
Skeleton(
[
Points([4, 5], [Points.Visibility.visible], label=1),
Points([5, 6], [Points.Visibility.hidden], label=2),
Points([6, 7], [Points.Visibility.absent], label=3),
],
label=0,
)
],
),
DatasetItem(
id=3,
annotations=[
Skeleton(
[
Points([7, 8], [Points.Visibility.visible], label=1),
Points([8, 9], [Points.Visibility.hidden], label=2),
Points([9, 10], [Points.Visibility.absent], label=3),
],
label=0,
)
],
),
],
categories=categories,
)

b = Dataset.from_iterable(
[
DatasetItem(
id=1,
annotations=[
Skeleton(
[
Points([1, 2], [Points.Visibility.hidden], label=2),
Points([0, 1], [Points.Visibility.visible]),
],
label=0,
)
],
), # matched, even though absent point is removed and order differs
DatasetItem(
id=2,
annotations=[
Skeleton(
[
Points([4, 5], [Points.Visibility.visible], label=1),
Points([5, 6], [Points.Visibility.hidden], label=2),
Points([6, 8], [Points.Visibility.absent], label=3),
],
label=0,
)
],
), # matched, even though absent point has different coordinates
DatasetItem(
id=3,
annotations=[
Skeleton(
[
Points([7, 8], [Points.Visibility.visible], label=1),
Points([8, 10], [Points.Visibility.hidden], label=2),
Points([9, 10], [Points.Visibility.absent], label=3),
],
label=0,
)
],
), # not matched, not-absent point has changed coordinates
],
categories=categories,
)

comp = ExactComparator()
_, unmatched, _, _, _ = comp.compare_datasets(a, b)

assert unmatched == [
{
"item": ("3", "default"),
"source": "a",
"ann": repr(
Skeleton(
[
Points([7, 8], [2], label=1),
Points([8, 9], [1], label=2),
Points([9, 10], [0], label=3),
],
label=0,
)
),
},
{
"item": ("3", "default"),
"source": "b",
"ann": repr(
Skeleton(
[
Points([7, 8], [2], label=1),
Points([8, 10], [1], label=2),
Points([9, 10], [0], label=3),
],
label=0,
)
),
},
]

@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_image_comparison(self):
a = Dataset.from_iterable(
Expand Down
34 changes: 29 additions & 5 deletions tests/unit/data_formats/test_yolo_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AnnotationType,
Bbox,
LabelCategories,
Mask,
Points,
PointsCategories,
Polygon,
Expand Down Expand Up @@ -466,14 +467,37 @@ def test_can_save_and_load_without_path_prefix(self, test_dir):

self.compare_datasets(source_dataset, parsed_dataset)

@mark_requirement(Requirements.DATUM_609)
def test_can_save_and_load_without_annotations(self, test_dir):
source_dataset = self._generate_random_dataset([{"annotations": 0}])
@mark_requirement(Requirements.DATUM_GENERAL_REQ)
def test_can_save_without_creating_annotation_file_and_load(self, test_dir):
categories = self._generate_random_dataset([]).categories()
source_dataset = Dataset.from_iterable(
[
DatasetItem(
id=1,
subset="train",
media=Image(data=np.ones((8, 8, 3))),
annotations=[
# Mask annotation is not supported by yolo8 formats, so should be omitted
Mask(np.array([[0, 1, 1, 1, 0]]), label=0),
],
)
],
categories=categories,
)
expected_dataset = Dataset.from_iterable(
[
DatasetItem(
id=1,
subset="train",
media=Image(data=np.ones((8, 8, 3))),
)
],
categories=categories,
)
self.CONVERTER.convert(source_dataset, test_dir, save_media=True)

assert os.listdir(osp.join(test_dir, "labels", "train")) == []
parsed_dataset = Dataset.import_from(test_dir, self.IMPORTER.NAME)
self.compare_datasets(source_dataset, parsed_dataset)
self.compare_datasets(expected_dataset, parsed_dataset)

def _check_inplace_save_writes_only_updated_data(self, test_dir, expected):
assert set(os.listdir(osp.join(test_dir, "images", "train"))) == {
Expand Down
Loading