From 3a746f51c0bd0d05052d85a0ce61e40056231027 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 18 Jul 2024 19:04:12 +0300 Subject: [PATCH 1/4] Add ability to import annotations into a datasource --- dagshub/data_engine/annotation/importer.py | 295 ++++++++++++++++++ dagshub/data_engine/model/datasource.py | 99 +++++- setup.py | 4 + .../annotation_import/test_load_location.py | 38 +++ .../annotation_import/test_path_remapping.py | 74 +++++ tests/mocks/repo_api.py | 2 + 6 files changed, 507 insertions(+), 5 deletions(-) create mode 100644 dagshub/data_engine/annotation/importer.py create mode 100644 tests/data_engine/annotation_import/test_load_location.py create mode 100644 tests/data_engine/annotation_import/test_path_remapping.py diff --git a/dagshub/data_engine/annotation/importer.py b/dagshub/data_engine/annotation/importer.py new file mode 100644 index 00000000..5a04e09e --- /dev/null +++ b/dagshub/data_engine/annotation/importer.py @@ -0,0 +1,295 @@ +from difflib import SequenceMatcher +from pathlib import Path, PurePosixPath +from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable + +from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip +from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs +from dagshub_annotation_converter.formats.label_studio.task import LabelStudioTask +from dagshub_annotation_converter.formats.yolo import YoloContext +from dagshub_annotation_converter.ir.image.annotations.base import IRAnnotationBase + +from dagshub.common.api import UserAPI +from dagshub.common.api.repo import PathNotFoundError +from dagshub.common.helpers import log_message + +if TYPE_CHECKING: + from dagshub.data_engine.model.datasource import Datasource + +AnnotationType = Literal["yolo", "cvat"] +AnnotationLocation = Literal["repo", "disk"] + + +class AnnotationsNotFoundError(Exception): + def __init__(self, path): + super().__init__(f'Annotations not found at path "{path}" in neither disk or repository.') + + +class CannotRemapPathError(Exception): + def __init__(self, a_path, b_path): + super().__init__(f"Cannot map from path {a_path} to path {b_path}") + + +class AnnotationImporter: + """ + Class for importing annotations into a datasource from different formats. + """ + + def __init__( + self, + ds: "Datasource", + annotations_type: AnnotationType, + annotations_file: Union[str, Path], + load_from: Optional[AnnotationLocation] = None, + **kwargs, + ): + self.ds = ds.__deepcopy__() + self.ds.clear_query() + self.annotations_type = annotations_type + self.annotations_file = Path(annotations_file) + self.load_from = load_from if load_from is not None else self.determine_load_location(ds, annotations_file) + self.additional_args = kwargs + + if self.annotations_type == "yolo": + if "yolo_type" not in kwargs: + raise ValueError( + "YOLO type must be provided in the additional arguments. " + 'Add `yolo_type="bbox"|"segmentation"|pose"` to the arguments.' 
+ ) + + def import_annotations(self) -> Mapping[str, Sequence[IRAnnotationBase]]: + # Double check that the annotation file exists + if self.load_from == "disk": + if not self.annotations_file.exists(): + raise AnnotationsNotFoundError(self.annotations_file) + elif self.load_from == "repo": + try: + self.ds.source.repoApi.list_path(self.annotations_file.as_posix()) + except PathNotFoundError: + raise AnnotationsNotFoundError(self.annotations_file) + + annotations_file = self.annotations_file + + with TemporaryDirectory() as tmp_dir: + tmp_dir_path = Path(tmp_dir) + if self.load_from == "repo": + self.download_annotations(tmp_dir_path) + annotations_file = tmp_dir_path / annotations_file.name + + # Convert annotations + log_message("Loading annotations...") + annotation_dict: Mapping[str, Sequence[IRAnnotationBase]] + if self.annotations_type == "yolo": + annotation_dict, _ = load_yolo_from_fs( + annotation_type=self.additional_args["yolo_type"], meta_file=annotations_file + ) + elif self.annotations_type == "cvat": + annotation_dict = load_cvat_from_zip(annotations_file) + + return annotation_dict + + def download_annotations(self, dest_dir: Path): + log_message("Downloading annotations from repository") + repoApi = self.ds.source.repoApi + if self.annotations_type == "cvat": + # Download just the annotation file + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) + elif self.annotations_type == "yolo": + # Download the dataset .yaml file and the images + annotations + # Download the file + repoApi.download(self.annotations_file.as_posix(), dest_dir, keep_source_prefix=True) + # Get the YOLO Context from the downloaded file + meta_file = dest_dir / self.annotations_file + context = YoloContext.from_yaml_file(meta_file, annotation_type=self.additional_args["yolo_type"]) + # Download the annotation data + assert context.path is not None + repoApi.download(self.annotations_file.parent / context.path, dest_dir, keep_source_prefix=True) + + @staticmethod + def determine_load_location(ds: "Datasource", annotations_path: Union[str, Path]) -> AnnotationLocation: + # Local files take priority + if Path(annotations_path).exists(): + return "disk" + + # Try to find it in the repo otherwise + try: + files = ds.source.repoApi.list_path(Path(annotations_path).as_posix()) + if len(files) > 0: + return "repo" + except PathNotFoundError: + pass + + # TODO: handle repo bucket too + + raise AnnotationsNotFoundError(annotations_path) + + def remap_annotations( + self, + annotations: Mapping[str, Sequence[IRAnnotationBase]], + remap_func: Optional[Callable[[str], Optional[str]]] = None, + ) -> Mapping[str, Sequence[IRAnnotationBase]]: + """ + Remaps the filenames in the annotations to the datasource's data points. + + Args: + annotations: Annotations to remap + remap_func: Function that maps from an annotation path to a datapoint path. 
\ + If None, we try to guess it by getting a datapoint and remapping that path + """ + if remap_func is None: + first_ann = list(annotations.keys())[0] + first_ann_filename = PurePosixPath(first_ann).name + queried = self.ds["path"].endswith(first_ann_filename).select("size").all() + dp_paths = [dp.path for dp in queried] + remap_func = self.guess_annotation_filename_remapping(first_ann, dp_paths) + + remapped = {} + + for filename, anns in annotations.items(): + new_filename = remap_func(filename) + if new_filename is None: + log_message( + f'Skipping annotation with filename "{filename}" because it could not be mapped to a datapoint' + ) + continue + for ann in anns: + assert ann.filename is not None + ann.filename = remap_func(ann.filename) + remapped[new_filename] = anns + + return remapped + + @staticmethod + def guess_annotation_filename_remapping( + annotation_path: str, datapoint_paths: list[str] + ) -> Callable[[str], Optional[str]]: + """ + Guesses the remapping function from the annotations to the data points. + + Args: + annotation_path: path of an existing annotations + datapoint_paths: paths of the data points in the datasource that end with the filename of this annotation + """ + + if len(datapoint_paths) == 0: + raise ValueError(f"No datapoints found that match the annotation path {annotation_path}") + + dp_path = datapoint_paths[0] + + if len(datapoint_paths) > 1: + # TODO: Maybe prompt user to choose a fitting datapoint (ordered by similarity) + dp_path = AnnotationImporter.get_best_fit_datapoint_path(annotation_path, datapoint_paths) + log_message(f'Multiple datapoints found for annotation path "{annotation_path}". Using "{dp_path}"') + + return AnnotationImporter.generate_path_map_func(annotation_path, dp_path) + + @staticmethod + def generate_path_map_func(ann_path: str, dp_path: str) -> Callable[[str], Optional[str]]: + ann_path_posix = PurePosixPath(ann_path) + dp_path_posix = PurePosixPath(dp_path) + + matcher = SequenceMatcher( + None, + ann_path_posix.parts, + dp_path_posix.parts, + ) + diff = matcher.get_matching_blocks() + + # Make sure that both sequences have the same end, get the common part. + # Then the rest is going to be the prefix that is either added or subtracted. 
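+        # Worked example (illustrative; mirrors the cases in tests/data_engine/annotation_import/test_path_remapping.py):
+        # for ann_path "images/1.png" and dp_path "data/images/1.png", the single matching block
+        # is ("images", "1.png") with a=0, b=1, so the datapoint paths only carry an extra
+        # "data/" prefix, and the returned function prepends that prefix to every annotation path.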
+ + # We need there to be only one match that is at the very end, otherwise we throw an error + if len(diff) != 2: + raise CannotRemapPathError(ann_path, dp_path) + + match = diff[0] + # Make sure that the match goes until the end + if match.a + match.size != len(ann_path_posix.parts) or match.b + match.size != len(dp_path_posix.parts): + raise CannotRemapPathError(ann_path, dp_path) + # ONE of the paths need to go until the start + if match.a != 0 and match.b != 0: + raise CannotRemapPathError(ann_path, dp_path) + + # If the match is total, just return identity + if match.a == 0 and match.b == 0: + + def identity_func(x: str) -> str: + return x.replace(ann_path, dp_path) + + return identity_func + + # The function that maps ends up being: + # - Get the common part of the path + # - Either remove the remainder, or add the prefix, depending on which is longer + + if match.b > match.a: + # Add a prefix + prefix = dp_path_posix.parts[match.a : match.b] + + def add_prefix(x: str) -> Optional[str]: + return PurePosixPath(*prefix, x).as_posix() + + return add_prefix + + else: + # Remove the prefix + def remove_prefix(x: str) -> Optional[str]: + p = PurePosixPath(x) + if len(p.parts) <= match.a: + return None + return PurePosixPath(*p.parts[match.a :]).as_posix() + + return remove_prefix + + @staticmethod + def get_best_fit_datapoint_path(ann_path: str, datapoint_paths: list[str]) -> str: + """ + Get the datapoint path that is the closest to the annotation path. + + Args: + ann_path: path of an annotation + datapoint_paths: paths of the data points in the datasource that end with the filename of this annotation + """ + best_match: Optional[str] = None + best_match_length: Optional[int] = None + + for dp_path in datapoint_paths: + ann_path_posix = PurePosixPath(ann_path) + dp_path_posix = PurePosixPath(dp_path) + + matcher = SequenceMatcher( + None, + ann_path_posix.parts, + dp_path_posix.parts, + ) + diff = matcher.get_matching_blocks() + + if len(diff) != 2: # Has multiple matches - bad + continue + match = diff[0] + if match.a != 0 and match.b != 0: + continue + + # Exact match - perfect! + if match.a == 0 and match.b == 0: + return dp_path + + if best_match_length is None or match.size > best_match_length: + best_match = dp_path + best_match_length = match.size + if best_match is None: + raise ValueError(f"No good match found for annotation path {ann_path} in the datasource.") + return best_match + + def convert_to_ls_tasks(self, annotations: Mapping[str, Sequence[IRAnnotationBase]]) -> Mapping[str, bytes]: + """ + Converts the annotations to Label Studio tasks. 
+ """ + current_user_id = UserAPI.get_current_user(self.ds.source.repoApi.host).user_id + tasks = {} + for filename, anns in annotations.items(): + t = LabelStudioTask(user_id=current_user_id) + t.data["image"] = self.ds.source.raw_path(filename) + t.add_ir_annotations(anns) + tasks[filename] = t.model_dump_json().encode("utf-8") + return tasks diff --git a/dagshub/data_engine/model/datasource.py b/dagshub/data_engine/model/datasource.py index d2ca7299..058a31b5 100644 --- a/dagshub/data_engine/model/datasource.py +++ b/dagshub/data_engine/model/datasource.py @@ -12,7 +12,7 @@ from dataclasses import dataclass, field from os import PathLike from pathlib import Path -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, Set, ContextManager, Tuple, Literal +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union, Set, ContextManager, Tuple, Literal, Callable import rich.progress @@ -26,6 +26,7 @@ from dagshub.common.helpers import prompt_user, http_request, log_message from dagshub.common.rich_util import get_rich_progress from dagshub.common.util import lazy_load, multi_urljoin, to_timestamp, exclude_if_none +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationType, AnnotationLocation from dagshub.data_engine.client.models import ( PreprocessingStatus, MetadataFieldSchema, @@ -76,8 +77,9 @@ class DatapointMetadataUpdateEntry(DataClassJsonMixin): value: str valueType: MetadataFieldType = field(metadata=config(encoder=lambda val: val.value)) allowMultiple: bool = False - timeZone: Optional[str] = field(default=None, - metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL)) + timeZone: Optional[str] = field( + default=None, metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL) + ) @dataclass @@ -1368,6 +1370,92 @@ def _test_not_comparing_other_ds(other): if type(other) is Datasource: raise DatasetFieldComparisonError() + def import_annotations_from_files( + self, + annotation_type: AnnotationType, + path: Union[str, Path], + field: str = "imported_annotation", + load_from: Optional[AnnotationLocation] = None, + remapping_function: Optional[Callable[[str], str]] = None, + **kwargs, + ): + """ + Imports annotations into the datasource from files + + The annotations will be downloaded and converted into Label Studio tasks, + that are then uploaded into the specified fields. + + If the annotations are stored in a repo and not locally, they are downloaded to a temporary directory. + + Caveats: + - YOLO: + - Images need to also be downloaded to get their dimensions. + - The .YAML file needs to have the ``path`` argument set to the relative path to the data. \ + We're using that to download the files + - You have to specify the ``yolo_type`` kwarg with the type of annotation to import + + Args: + annotation_type: Type of annotations to import. Possible values are ``yolo`` and ``cvat`` + path: If YOLO - path to the .yaml file, if CVAT - path to the .zip file. \ + Can be either on disk or in repository + field: Which field to upload the annotations into. \ + If it's an existing field, it has to be a blob field, \ + and it will have the annotations flag set afterwards. + load_from: Force specify where to get the files from. \ + By default, we're trying to load files from the disk first, and then repository. + If this is specified, then that check is being skipped and \ + we'll try to download from the specified location. 
+            remapping_function: Function that maps an annotation path to the corresponding datapoint path. \
+                If None, we try to make a best guess based on the first imported annotation. \
+                This might fail if there is no matching datapoint in the datasource for some annotations \
+                or if the paths are wildly different.
+
+        Keyword Args:
+            yolo_type: Type of YOLO annotations to import. Either ``bbox``, ``segmentation`` or ``pose``.
+
+        Example to import segmentation annotations into an ``imported_annotations`` field,
+        using YOLO information from an ``annotations.yaml`` file (can be local or in the repo)::
+
+            ds.import_annotations_from_files(
+                annotation_type="yolo",
+                path="annotations.yaml",
+                field="imported_annotations",
+                yolo_type="segmentation"
+            )
+        """
+
+        # Make sure the annotation field exists, is a blob field + has the annotation tag
+        existing_fields = [f for f in self.fields if f.name == field]
+        if len(existing_fields) != 0:
+            f = existing_fields[0]
+            if f.valueType != MetadataFieldType.BLOB:
+                raise RuntimeError(
+                    f"Field {f.name} is not a blob field. "
+                    f"Choose a new field or an existing blob field to upload annotations to."
+                )
+        self.metadata_field(field).set_type(bytes).set_annotation().apply()
+
+        # Run import
+        importer = AnnotationImporter(
+            ds=self,
+            annotations_type=annotation_type,
+            annotations_file=path,
+            load_from=load_from,
+            **kwargs,
+        )
+        annotation_dict = importer.import_annotations()
+
+        annotation_dict = importer.remap_annotations(annotation_dict, remap_func=remapping_function)
+
+        with rich_console.status("Converting annotations to tasks..."):
+            tasks = importer.convert_to_ls_tasks(annotation_dict)
+
+        with self.metadata_context() as ctx:
+            for dp, task in tasks.items():
+                ctx.update_metadata(dp, {field: task})
+
+        log_message(f'Done! 
Uploaded annotations for {len(tasks)} datapoints to field "{field}"') + class MetadataContextManager: """ @@ -1510,8 +1598,9 @@ def _get_datetime_utc_offset(t): @dataclass class DatasourceQuery(DataClassJsonMixin): as_of: Optional[int] = field(default=None, metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL)) - time_zone: Optional[str] = field(default=None, - metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL)) + time_zone: Optional[str] = field( + default=None, metadata=config(exclude=exclude_if_none, letter_case=LetterCase.CAMEL) + ) select: Optional[List[Dict]] = field(default=None, metadata=config(exclude=exclude_if_none)) filter: "QueryFilterTree" = field( default=QueryFilterTree(), diff --git a/setup.py b/setup.py index ede0121e..13916801 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,10 @@ def get_version(rel_path: str) -> str: "python-dateutil", "tenacity~=8.2.3", "boto3", + # FIXME: GO BACK TO REGULAR IMPORT AFTER THIS IS MERGED + "dagshub-annotation-converter @ " + + "git+https://github.com/DagsHub/" + + "dagshub-annotation-converter@refactor/static-annotations#egg=dagshub-annotation-converter", ] extras_require = { diff --git a/tests/data_engine/annotation_import/test_load_location.py b/tests/data_engine/annotation_import/test_load_location.py new file mode 100644 index 00000000..41940786 --- /dev/null +++ b/tests/data_engine/annotation_import/test_load_location.py @@ -0,0 +1,38 @@ +import os +from pathlib import Path +from typing import cast + +import pytest + +from dagshub.data_engine.annotation.importer import AnnotationImporter, AnnotationsNotFoundError +from dagshub.data_engine.model.datasource import Datasource +from tests.mocks.repo_api import MockRepoAPI +from tests.util import remember_cwd + + +@pytest.fixture +def annotation_ds(ds) -> Datasource: + ds.source.path = "repo://kirill/repo:main/data/images" + + repoApi = cast(MockRepoAPI, ds.source.repoApi) + repoApi.add_repo_file("data/labels/1.txt", b"1") + + return ds + + +def test_load_location_on_disk(annotation_ds, tmp_path): + """Also tests that disk takes priority over repo.""" + with remember_cwd(): + os.chdir(tmp_path) + Path("data/labels").mkdir(parents=True) + Path("data/labels/1.txt").write_text("1") + assert AnnotationImporter.determine_load_location(annotation_ds, "data/labels/1.txt") == "disk" + + +def test_load_location_on_repo(annotation_ds): + assert AnnotationImporter.determine_load_location(annotation_ds, "data/labels/1.txt") == "repo" + + +def test_load_location_fails(annotation_ds): + with pytest.raises(AnnotationsNotFoundError): + AnnotationImporter.determine_load_location(annotation_ds, "random_path") diff --git a/tests/data_engine/annotation_import/test_path_remapping.py b/tests/data_engine/annotation_import/test_path_remapping.py new file mode 100644 index 00000000..81775c19 --- /dev/null +++ b/tests/data_engine/annotation_import/test_path_remapping.py @@ -0,0 +1,74 @@ +import pytest + +from dagshub.data_engine.annotation.importer import AnnotationImporter, CannotRemapPathError + + +@pytest.mark.parametrize( + "in_path, expected", + [ + ("images/1.png", "data/images/1.png"), + ("images/2.png", "data/images/2.png"), + ("3.png", "data/3.png"), + ("very/long/subpath/4.png", "data/very/long/subpath/4.png"), + ], +) +def test_dp_path_is_longer(in_path, expected): + ann_path = "images/1.png" + dp_path = "data/images/1.png" + + remap_func = AnnotationImporter.generate_path_map_func(ann_path, dp_path) + + actual = remap_func(in_path) + assert actual == expected + + 
+@pytest.mark.parametrize( + "in_path, expected", + [ + ("data/images/1.png", "images/1.png"), + ("data/images/2.png", "images/2.png"), + ("data/3.png", "3.png"), + ("data/very/long/subpath/4.png", "very/long/subpath/4.png"), + ("5.png", None), + ], +) +def test_dp_path_is_shorter(in_path, expected): + ann_path = "data/images/1.png" + dp_path = "images/1.png" + + remap_func = AnnotationImporter.generate_path_map_func(ann_path, dp_path) + + actual = remap_func(in_path) + assert actual == expected + + +@pytest.mark.parametrize( + "dp_path", + [ + "data/different_prefix/1.png", + "data/images/more/1.png", + "data/more/images/1.png", + "different_prefix/images/1.png", # This case has too many edge cases, so we also don't handle this + ], +) +def test_different_paths_throw_errors(dp_path): + ann_path = "data/images/1.png" + + with pytest.raises(CannotRemapPathError): + AnnotationImporter.generate_path_map_func(ann_path, dp_path) + + +def test_multiple_dp_matching(): + ann_path = "images/blabla/1.png" + candidates = [ + "images/1.png", + "data/images/1.png", + "data/images/blabla/1.png", + "images/blabla/blabla/1.png", + "images/random/1.png", + ] + + expected = "data/images/blabla/1.png" + actual = AnnotationImporter.get_best_fit_datapoint_path(ann_path, candidates) + + assert actual == expected diff --git a/tests/mocks/repo_api.py b/tests/mocks/repo_api.py index af8db654..4344d001 100644 --- a/tests/mocks/repo_api.py +++ b/tests/mocks/repo_api.py @@ -136,6 +136,8 @@ def list_path(self, path: str, revision: Optional[str] = None, include_size: boo revision = self.default_branch content = self.repo_contents.get(revision, {}).get(path) if content is None: + if self.get_file(path, revision) is not None: + return [self.generate_content_api_entry(path)] raise PathNotFoundError return content From 592682745cbc25f46ad0b6a04b11a3bc48b644bc Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 18 Jul 2024 19:10:38 +0300 Subject: [PATCH 2/4] Python 3.8 type hints --- dagshub/data_engine/annotation/importer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dagshub/data_engine/annotation/importer.py b/dagshub/data_engine/annotation/importer.py index 5a04e09e..88018a9d 100644 --- a/dagshub/data_engine/annotation/importer.py +++ b/dagshub/data_engine/annotation/importer.py @@ -1,7 +1,7 @@ from difflib import SequenceMatcher from pathlib import Path, PurePosixPath from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable +from typing import TYPE_CHECKING, Literal, Optional, Union, Sequence, Mapping, Callable, List from dagshub_annotation_converter.converters.cvat import load_cvat_from_zip from dagshub_annotation_converter.converters.yolo import load_yolo_from_fs @@ -161,7 +161,7 @@ def remap_annotations( @staticmethod def guess_annotation_filename_remapping( - annotation_path: str, datapoint_paths: list[str] + annotation_path: str, datapoint_paths: List[str] ) -> Callable[[str], Optional[str]]: """ Guesses the remapping function from the annotations to the data points. @@ -242,7 +242,7 @@ def remove_prefix(x: str) -> Optional[str]: return remove_prefix @staticmethod - def get_best_fit_datapoint_path(ann_path: str, datapoint_paths: list[str]) -> str: + def get_best_fit_datapoint_path(ann_path: str, datapoint_paths: List[str]) -> str: """ Get the datapoint path that is the closest to the annotation path. 
From 6ebd7aabf6b69d2fe177e437cc7f08646c3a80f7 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Thu, 18 Jul 2024 19:20:36 +0300 Subject: [PATCH 3/4] Add missing init file for tests --- tests/data_engine/annotation_import/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/data_engine/annotation_import/__init__.py diff --git a/tests/data_engine/annotation_import/__init__.py b/tests/data_engine/annotation_import/__init__.py new file mode 100644 index 00000000..e69de29b From 8b17437bfa4f9d0d1b0c7e46b9d77323a96d13c0 Mon Sep 17 00:00:00 2001 From: KBolashev Date: Mon, 26 Aug 2024 14:48:22 +0300 Subject: [PATCH 4/4] Change annotation converter version --- setup.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 13916801..565f5592 100644 --- a/setup.py +++ b/setup.py @@ -42,10 +42,7 @@ def get_version(rel_path: str) -> str: "python-dateutil", "tenacity~=8.2.3", "boto3", - # FIXME: GO BACK TO REGULAR IMPORT AFTER THIS IS MERGED - "dagshub-annotation-converter @ " - + "git+https://github.com/DagsHub/" - + "dagshub-annotation-converter@refactor/static-annotations#egg=dagshub-annotation-converter", + "dagshub-annotation-converter>=0.1.0", ] extras_require = {
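
Usage sketch (illustrative; not part of the patches above). Assuming the usual
``datasources.get_datasource`` entry point and a datasource whose repo contains an
``annotations.yaml`` YOLO metadata file, the API added in PATCH 1/4 would be called
roughly like this; the repo and datasource names below are placeholders::

    from dagshub.data_engine import datasources

    # Placeholder repo/datasource names, for illustration only
    ds = datasources.get_datasource("<user>/<repo>", name="my-datasource")

    # Import YOLO segmentation annotations into a blob field named "imported_annotations".
    # The yolo_type kwarg is required for YOLO imports (bbox, segmentation or pose).
    ds.import_annotations_from_files(
        annotation_type="yolo",
        path="annotations.yaml",
        field="imported_annotations",
        yolo_type="segmentation",
    )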