diff --git a/datumaro/components/operations.py b/datumaro/components/operations.py index c11a528b6d..41d19ee7a0 100644 --- a/datumaro/components/operations.py +++ b/datumaro/components/operations.py @@ -37,6 +37,7 @@ Label, LabelCategories, MaskCategories, + Points, PointsCategories, ) from datumaro.components.cli_plugin import CliPlugin @@ -1856,7 +1857,7 @@ def _compare_categories(self, a, b): except AssertionError as e: errors.append({"type": "points", "message": str(e)}) - def _compare_annotations(self, a, b): + def _compare_annotations(self, a: Annotation, b: Annotation): ignored_fields = self.ignored_fields ignored_attrs = self.ignored_attrs @@ -1866,6 +1867,16 @@ def _compare_annotations(self, a, b): a_fields["attributes"] = filter_dict(a.attributes, ignored_attrs) b_fields["attributes"] = filter_dict(b.attributes, ignored_attrs) + if a.type == b.type == AnnotationType.skeleton and "elements" not in ignored_fields: + a_fields["elements"] = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, a.elements), + key=lambda p: p.label if p.label is not None else -1, + ) + b_fields["elements"] = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, b.elements), + key=lambda p: p.label if p.label is not None else -1, + ) + result = a.wrap(**a_fields) == b.wrap(**b_fields) return result diff --git a/datumaro/plugins/yolo_format/converter.py b/datumaro/plugins/yolo_format/converter.py index 8b812948b9..5c5f6e7417 100644 --- a/datumaro/plugins/yolo_format/converter.py +++ b/datumaro/plugins/yolo_format/converter.py @@ -194,6 +194,10 @@ def _export_media(self, item: DatasetItem, subset_img_dir: str) -> str: except Exception as e: self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset)) + def _save_annotation_file(self, annotation_path, yolo_annotation): + with open(annotation_path, "w", encoding="utf-8") as f: + f.write(yolo_annotation) + def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None: try: height, width = item.media.size @@ -208,8 +212,7 @@ def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None: annotation_path = osp.join(subset_dir, f"{item.id}{YoloPath.LABELS_EXT}") os.makedirs(osp.dirname(annotation_path), exist_ok=True) - with open(annotation_path, "w", encoding="utf-8") as f: - f.write(yolo_annotation) + self._save_annotation_file(annotation_path, yolo_annotation) except Exception as e: self._ctx.error_policy.report_item_error(e, item_id=(item.id, item.subset)) @@ -288,9 +291,9 @@ def __init__( super().__init__(extractor, save_dir, add_path_prefix=add_path_prefix, **kwargs) self._config_filename = config_file or YOLOv8Path.DEFAULT_CONFIG_FILE - def _export_item_annotation(self, item: DatasetItem, subset_dir: str) -> None: - if len(item.annotations) > 0: - super()._export_item_annotation(item, subset_dir) + def _save_annotation_file(self, annotation_path, yolo_annotation): + if yolo_annotation: + super()._save_annotation_file(annotation_path, yolo_annotation) @classmethod def build_cmdline_parser(cls, **kwargs): diff --git a/datumaro/util/test_utils.py b/datumaro/util/test_utils.py index 5e45aa1e76..77a0d524af 100644 --- a/datumaro/util/test_utils.py +++ b/datumaro/util/test_utils.py @@ -10,13 +10,14 @@ import unittest import unittest.mock import warnings +from copy import copy from enum import Enum, auto from glob import glob from typing import Any, Callable, Collection, Optional, Union from typing_extensions import Literal -from datumaro.components.annotation import AnnotationType +from datumaro.components.annotation import Annotation, AnnotationType, Points from datumaro.components.dataset import Dataset, IDataset from datumaro.components.media import Image, MultiframeImage, PointCloud from datumaro.util import filter_dict, find @@ -115,26 +116,34 @@ def compare_categories(test, expected, actual): IGNORE_ALL = "*" -def compare_annotations(expected, actual, ignored_attrs=None): - if not ignored_attrs: +def compare_annotations(expected: Annotation, actual: Annotation, ignored_attrs=None): + is_skeleton = expected.type == actual.type == AnnotationType.skeleton + if not ignored_attrs and not is_skeleton: return expected == actual - a_attr = expected.attributes - b_attr = actual.attributes + ignored_attrs = ignored_attrs or [] + + expected = copy(expected) + actual = copy(actual) if ignored_attrs != IGNORE_ALL: - expected.attributes = filter_dict(a_attr, exclude_keys=ignored_attrs) - actual.attributes = filter_dict(b_attr, exclude_keys=ignored_attrs) + expected.attributes = filter_dict(expected.attributes, exclude_keys=ignored_attrs) + actual.attributes = filter_dict(actual.attributes, exclude_keys=ignored_attrs) else: expected.attributes = {} actual.attributes = {} - r = expected == actual - - expected.attributes = a_attr - actual.attributes = b_attr + if is_skeleton: + expected.elements = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, expected.elements), + key=lambda p: p.label if p.label is not None else -1, + ) + actual.elements = sorted( + filter(lambda p: p.visibility[0] != Points.Visibility.absent, actual.elements), + key=lambda p: p.label if p.label is not None else -1, + ) - return r + return expected == actual def compare_datasets( diff --git a/tests/test_diff.py b/tests/test_diff.py index e52a4a35e2..30a394a51e 100644 --- a/tests/test_diff.py +++ b/tests/test_diff.py @@ -2,7 +2,17 @@ import numpy as np -from datumaro.components.annotation import Bbox, Caption, Label, Mask, Points +from datumaro.components.annotation import ( + AnnotationType, + Bbox, + Caption, + Label, + LabelCategories, + Mask, + Points, + PointsCategories, + Skeleton, +) from datumaro.components.extractor import DEFAULT_SUBSET_NAME, DatasetItem from datumaro.components.media import Image from datumaro.components.operations import DistanceComparator, ExactComparator @@ -268,6 +278,146 @@ def test_annotation_comparison(self): self.assertEqual(2, len(unmatched), unmatched) self.assertEqual(0, len(errors), errors) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_skeleton_annotation_comparison(self): + categories = { + AnnotationType.label: LabelCategories.from_iterable( + [ + "skeleton", + ("point1", "skeleton"), + ("point2", "skeleton"), + ("point3", "skeleton"), + ] + ), + AnnotationType.points: PointsCategories.from_iterable( + [ + (0, ["point1", "point2", "point3"], set()), + ] + ), + } + a = Dataset.from_iterable( + [ + DatasetItem( + id=1, + annotations=[ + Skeleton( + [ + Points([0, 1], [Points.Visibility.visible]), + Points([1, 2], [Points.Visibility.hidden], label=2), + Points([2, 3], [Points.Visibility.absent], label=3), + ], + label=0, + ) + ], + ), + DatasetItem( + id=2, + annotations=[ + Skeleton( + [ + Points([4, 5], [Points.Visibility.visible], label=1), + Points([5, 6], [Points.Visibility.hidden], label=2), + Points([6, 7], [Points.Visibility.absent], label=3), + ], + label=0, + ) + ], + ), + DatasetItem( + id=3, + annotations=[ + Skeleton( + [ + Points([7, 8], [Points.Visibility.visible], label=1), + Points([8, 9], [Points.Visibility.hidden], label=2), + Points([9, 10], [Points.Visibility.absent], label=3), + ], + label=0, + ) + ], + ), + ], + categories=categories, + ) + + b = Dataset.from_iterable( + [ + DatasetItem( + id=1, + annotations=[ + Skeleton( + [ + Points([1, 2], [Points.Visibility.hidden], label=2), + Points([0, 1], [Points.Visibility.visible]), + ], + label=0, + ) + ], + ), # matched, even though absent point is removed and order differs + DatasetItem( + id=2, + annotations=[ + Skeleton( + [ + Points([4, 5], [Points.Visibility.visible], label=1), + Points([5, 6], [Points.Visibility.hidden], label=2), + Points([6, 8], [Points.Visibility.absent], label=3), + ], + label=0, + ) + ], + ), # matched, even though absent point has different coordinates + DatasetItem( + id=3, + annotations=[ + Skeleton( + [ + Points([7, 8], [Points.Visibility.visible], label=1), + Points([8, 10], [Points.Visibility.hidden], label=2), + Points([9, 10], [Points.Visibility.absent], label=3), + ], + label=0, + ) + ], + ), # not matched, not-absent point has changed coordinates + ], + categories=categories, + ) + + comp = ExactComparator() + _, unmatched, _, _, _ = comp.compare_datasets(a, b) + + assert unmatched == [ + { + "item": ("3", "default"), + "source": "a", + "ann": repr( + Skeleton( + [ + Points([7, 8], [2], label=1), + Points([8, 9], [1], label=2), + Points([9, 10], [0], label=3), + ], + label=0, + ) + ), + }, + { + "item": ("3", "default"), + "source": "b", + "ann": repr( + Skeleton( + [ + Points([7, 8], [2], label=1), + Points([8, 10], [1], label=2), + Points([9, 10], [0], label=3), + ], + label=0, + ) + ), + }, + ] + @mark_requirement(Requirements.DATUM_GENERAL_REQ) def test_image_comparison(self): a = Dataset.from_iterable( diff --git a/tests/unit/data_formats/test_yolo_format.py b/tests/unit/data_formats/test_yolo_format.py index 0eba149913..3cfb8d8344 100644 --- a/tests/unit/data_formats/test_yolo_format.py +++ b/tests/unit/data_formats/test_yolo_format.py @@ -18,6 +18,7 @@ AnnotationType, Bbox, LabelCategories, + Mask, Points, PointsCategories, Polygon, @@ -466,14 +467,37 @@ def test_can_save_and_load_without_path_prefix(self, test_dir): self.compare_datasets(source_dataset, parsed_dataset) - @mark_requirement(Requirements.DATUM_609) - def test_can_save_and_load_without_annotations(self, test_dir): - source_dataset = self._generate_random_dataset([{"annotations": 0}]) + @mark_requirement(Requirements.DATUM_GENERAL_REQ) + def test_can_save_without_creating_annotation_file_and_load(self, test_dir): + categories = self._generate_random_dataset([]).categories() + source_dataset = Dataset.from_iterable( + [ + DatasetItem( + id=1, + subset="train", + media=Image(data=np.ones((8, 8, 3))), + annotations=[ + # Mask annotation is not supported by yolo8 formats, so should be omitted + Mask(np.array([[0, 1, 1, 1, 0]]), label=0), + ], + ) + ], + categories=categories, + ) + expected_dataset = Dataset.from_iterable( + [ + DatasetItem( + id=1, + subset="train", + media=Image(data=np.ones((8, 8, 3))), + ) + ], + categories=categories, + ) self.CONVERTER.convert(source_dataset, test_dir, save_media=True) - assert os.listdir(osp.join(test_dir, "labels", "train")) == [] parsed_dataset = Dataset.import_from(test_dir, self.IMPORTER.NAME) - self.compare_datasets(source_dataset, parsed_dataset) + self.compare_datasets(expected_dataset, parsed_dataset) def _check_inplace_save_writes_only_updated_data(self, test_dir, expected): assert set(os.listdir(osp.join(test_dir, "images", "train"))) == {