From 1c0a49fc80e994a2075084dde4fcf9438d33d053 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Wed, 2 Aug 2023 20:30:05 +0300 Subject: [PATCH] Add auto-annotation support to SDK and CLI (#6483) Introduce a `cvat-sdk auto-annotate` command that downloads data for a task, then runs a function on the local computer on that data, and uploads resulting annotations back to the task. To support this functionality, add a new SDK module, `cvat_sdk.auto_annotation`, that contains an interface that the functions must follow, and a driver that applies a function to a task. This will let users easily annotate their tasks with custom DL models. --- .github/workflows/full.yml | 13 +- .github/workflows/main.yml | 2 +- CHANGELOG.md | 5 + cvat-cli/src/cvat_cli/__main__.py | 1 + cvat-cli/src/cvat_cli/cli.py | 33 +- cvat-cli/src/cvat_cli/parser.py | 35 + cvat-sdk/cvat_sdk/auto_annotation/__init__.py | 17 + cvat-sdk/cvat_sdk/auto_annotation/driver.py | 303 +++++++++ .../auto_annotation/functions/__init__.py | 0 .../auto_annotation/functions/yolov8n.py | 36 + .../cvat_sdk/auto_annotation/interface.py | 166 +++++ cvat-sdk/cvat_sdk/datasets/common.py | 3 + cvat-sdk/cvat_sdk/datasets/task_dataset.py | 7 +- .../openapi-generator/setup.mustache | 1 + tests/python/cli/example_function.py | 23 + tests/python/cli/test_cli.py | 26 + tests/python/sdk/test_auto_annotation.py | 629 ++++++++++++++++++ tests/python/sdk/test_datasets.py | 1 + 18 files changed, 1293 insertions(+), 8 deletions(-) create mode 100644 cvat-sdk/cvat_sdk/auto_annotation/__init__.py create mode 100644 cvat-sdk/cvat_sdk/auto_annotation/driver.py create mode 100644 cvat-sdk/cvat_sdk/auto_annotation/functions/__init__.py create mode 100644 cvat-sdk/cvat_sdk/auto_annotation/functions/yolov8n.py create mode 100644 cvat-sdk/cvat_sdk/auto_annotation/interface.py create mode 100644 tests/python/cli/example_function.py create mode 100644 tests/python/sdk/test_auto_annotation.py diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index d2f0a23a3c32..b90a8599c105 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -152,16 +152,19 @@ jobs: name: expected_schema path: cvat/schema-expected.yml - - name: Running REST API and SDK tests - id: run_tests + - name: Generate SDK run: | pip3 install -r cvat-sdk/gen/requirements.txt ./cvat-sdk/gen/generate.sh - pip3 install -r ./tests/python/requirements.txt - pip3 install -e ./cvat-sdk - pip3 install -e ./cvat-cli + - name: Install SDK + run: | + pip3 install -r ./tests/python/requirements.txt \ + -e './cvat-sdk[pytorch,ultralytics]' -e ./cvat-cli + - name: Running REST API and SDK tests + id: run_tests + run: | pytest tests/python/ - name: Creating a log file from cvat containers diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 023645d52380..b1a85b809fd7 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -160,7 +160,7 @@ jobs: - name: Install SDK run: | pip3 install -r ./tests/python/requirements.txt \ - -e './cvat-sdk[pytorch]' -e ./cvat-cli + -e './cvat-sdk[pytorch,ultralytics]' -e ./cvat-cli - name: Run REST API and SDK tests id: run_tests diff --git a/CHANGELOG.md b/CHANGELOG.md index 87143015a4eb..4e2e25a1e774 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - \[SDK\] A `DeferredTqdmProgressReporter` class, which doesn't have glitchy output like `TqdmProgressReporter` in certain circumstances () +- \[SDK, CLI\] A 
`cvat_sdk.auto_annotation` module that provides + functionality for automatically annotating a task by running a + user-provided function on the local machine, and a corresponding CLI command + (`auto-annotate`) + () ### Changed diff --git a/cvat-cli/src/cvat_cli/__main__.py b/cvat-cli/src/cvat_cli/__main__.py index 673adf3e6ae7..2448587245f9 100755 --- a/cvat-cli/src/cvat_cli/__main__.py +++ b/cvat-cli/src/cvat_cli/__main__.py @@ -59,6 +59,7 @@ def main(args: List[str] = None): "upload": CLI.tasks_upload, "export": CLI.tasks_export, "import": CLI.tasks_import, + "auto-annotate": CLI.tasks_auto_annotate, } parser = make_cmdline_parser() parsed_args = parser.parse_args(args) diff --git a/cvat-cli/src/cvat_cli/cli.py b/cvat-cli/src/cvat_cli/cli.py index 1c480929801a..d0417944aa62 100644 --- a/cvat-cli/src/cvat_cli/cli.py +++ b/cvat-cli/src/cvat_cli/cli.py @@ -4,9 +4,13 @@ from __future__ import annotations +import importlib +import importlib.util import json -from typing import Dict, List, Sequence, Tuple +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple +import cvat_sdk.auto_annotation as cvataa from cvat_sdk import Client, models from cvat_sdk.core.helpers import DeferredTqdmProgressReporter from cvat_sdk.core.proxies.tasks import ResourceType @@ -140,3 +144,30 @@ def tasks_import(self, filename: str, *, status_check_period: int = 2) -> None: status_check_period=status_check_period, pbar=DeferredTqdmProgressReporter(), ) + + def tasks_auto_annotate( + self, + task_id: int, + *, + function_module: Optional[str] = None, + function_file: Optional[Path] = None, + clear_existing: bool = False, + allow_unmatched_labels: bool = False, + ) -> None: + if function_module is not None: + function = importlib.import_module(function_module) + elif function_file is not None: + module_spec = importlib.util.spec_from_file_location("__cvat_function__", function_file) + function = importlib.util.module_from_spec(module_spec) + module_spec.loader.exec_module(function) + else: + assert False, "function identification arguments missing" + + cvataa.annotate_task( + self.client, + task_id, + function, + pbar=DeferredTqdmProgressReporter(), + clear_existing=clear_existing, + allow_unmatched_labels=allow_unmatched_labels, + ) diff --git a/cvat-cli/src/cvat_cli/parser.py b/cvat-cli/src/cvat_cli/parser.py index 32630baaad8a..c1a7e6c3abdb 100644 --- a/cvat-cli/src/cvat_cli/parser.py +++ b/cvat-cli/src/cvat_cli/parser.py @@ -10,6 +10,7 @@ import os import textwrap from distutils.util import strtobool +from pathlib import Path from cvat_sdk.core.proxies.tasks import ResourceType @@ -369,6 +370,40 @@ def make_cmdline_parser() -> argparse.ArgumentParser: help="time interval between checks if archive processing was finished, in seconds", ) + ####################################################################### + # Auto-annotate + ####################################################################### + auto_annotate_task_parser = task_subparser.add_parser( + "auto-annotate", + description="Automatically annotate a CVAT task by running a function on the local machine.", + ) + auto_annotate_task_parser.add_argument("task_id", type=int, help="task ID") + + function_group = auto_annotate_task_parser.add_mutually_exclusive_group(required=True) + + function_group.add_argument( + "--function-module", + metavar="MODULE", + help="qualified name of a module to use as the function", + ) + + function_group.add_argument( + "--function-file", + metavar="PATH", + type=Path, + help="path to a Python 
source file to use as the function", + ) + + auto_annotate_task_parser.add_argument( + "--clear-existing", action="store_true", help="Remove existing annotations from the task" + ) + + auto_annotate_task_parser.add_argument( + "--allow-unmatched-labels", + action="store_true", + help="Allow the function to declare labels not configured in the task", + ) + return parser diff --git a/cvat-sdk/cvat_sdk/auto_annotation/__init__.py b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py new file mode 100644 index 000000000000..e5dbdf9fcc42 --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/__init__.py @@ -0,0 +1,17 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from .driver import BadFunctionError, annotate_task +from .interface import ( + DetectionFunction, + DetectionFunctionContext, + DetectionFunctionSpec, + keypoint, + keypoint_spec, + label_spec, + rectangle, + shape, + skeleton, + skeleton_label_spec, +) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/driver.py b/cvat-sdk/cvat_sdk/auto_annotation/driver.py new file mode 100644 index 000000000000..8c1c71b46e8b --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/driver.py @@ -0,0 +1,303 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import logging +from typing import List, Mapping, Optional, Sequence + +import attrs + +import cvat_sdk.models as models +from cvat_sdk.core import Client +from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter +from cvat_sdk.datasets.task_dataset import TaskDataset + +from .interface import DetectionFunction, DetectionFunctionContext, DetectionFunctionSpec + + +class BadFunctionError(Exception): + """ + An exception that signifies that an auto-detection function has violated some constraint + set by its interface. 
+ """ + + +class _AnnotationMapper: + @attrs.frozen + class _MappedLabel: + id: int + sublabel_mapping: Mapping[int, Optional[int]] + expected_num_elements: int = 0 + + _label_mapping: Mapping[int, Optional[_MappedLabel]] + + def _build_mapped_label( + self, fun_label: models.ILabel, ds_labels_by_name: Mapping[str, models.ILabel] + ) -> Optional[_MappedLabel]: + if getattr(fun_label, "attributes", None): + raise BadFunctionError(f"label attributes are currently not supported") + + ds_label = ds_labels_by_name.get(fun_label.name) + if ds_label is None: + if not self._allow_unmatched_labels: + raise BadFunctionError(f"label {fun_label.name!r} is not in dataset") + + self._logger.info( + "label %r is not in dataset; any annotations using it will be ignored", + fun_label.name, + ) + return None + + sl_map = {} + + if getattr(fun_label, "sublabels", []): + fun_label_type = getattr(fun_label, "type", "any") + if fun_label_type != "skeleton": + raise BadFunctionError( + f"label {fun_label.name!r} with sublabels has type {fun_label_type!r} (should be 'skeleton')" + ) + + ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels} + + for fun_sl in fun_label.sublabels: + if not hasattr(fun_sl, "id"): + raise BadFunctionError( + f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has no ID" + ) + + if fun_sl.id in sl_map: + raise BadFunctionError( + f"sublabel {fun_sl.name!r} of label {fun_label.name!r} has same ID as another sublabel ({fun_sl.id})" + ) + + ds_sl = ds_sublabels_by_name.get(fun_sl.name) + if not ds_sl: + if not self._allow_unmatched_labels: + raise BadFunctionError( + f"sublabel {fun_sl.name!r} of label {fun_label.name!r} is not in dataset" + ) + + self._logger.info( + "sublabel %r of label %r is not in dataset; any annotations using it will be ignored", + fun_sl.name, + fun_label.name, + ) + sl_map[fun_sl.id] = None + continue + + sl_map[fun_sl.id] = ds_sl.id + + return self._MappedLabel( + ds_label.id, sublabel_mapping=sl_map, expected_num_elements=len(ds_label.sublabels) + ) + + def __init__( + self, + logger: logging.Logger, + fun_labels: Sequence[models.ILabel], + ds_labels: Sequence[models.ILabel], + *, + allow_unmatched_labels: bool, + ) -> None: + self._logger = logger + self._allow_unmatched_labels = allow_unmatched_labels + + ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels} + + self._label_mapping = {} + + for fun_label in fun_labels: + if not hasattr(fun_label, "id"): + raise BadFunctionError(f"label {fun_label.name!r} has no ID") + + if fun_label.id in self._label_mapping: + raise BadFunctionError( + f"label {fun_label.name} has same ID as another label ({fun_label.id})" + ) + + self._label_mapping[fun_label.id] = self._build_mapped_label( + fun_label, ds_labels_by_name + ) + + def validate_and_remap(self, shapes: List[models.LabeledShapeRequest], ds_frame: int) -> None: + new_shapes = [] + + for shape in shapes: + if hasattr(shape, "id"): + raise BadFunctionError("function output shape with preset id") + + if hasattr(shape, "source"): + raise BadFunctionError("function output shape with preset source") + shape.source = "auto" + + if shape.frame != 0: + raise BadFunctionError( + f"function output shape with unexpected frame number ({shape.frame})" + ) + + shape.frame = ds_frame + + try: + mapped_label = self._label_mapping[shape.label_id] + except KeyError: + raise BadFunctionError( + f"function output shape with unknown label ID ({shape.label_id})" + ) + + if not mapped_label: + continue + + shape.label_id = mapped_label.id + + if 
getattr(shape, "attributes", None): + raise BadFunctionError( + "function output shape with attributes, which is not yet supported" + ) + + new_shapes.append(shape) + + if shape.type.value == "skeleton": + new_elements = [] + seen_sl_ids = set() + + for element in shape.elements: + if hasattr(element, "id"): + raise BadFunctionError("function output shape element with preset id") + + if hasattr(element, "source"): + raise BadFunctionError("function output shape element with preset source") + element.source = "auto" + + if element.frame != 0: + raise BadFunctionError( + f"function output shape element with unexpected frame number ({element.frame})" + ) + + element.frame = ds_frame + + if element.type.value != "points": + raise BadFunctionError( + f"function output skeleton with element type other than 'points' ({element.type.value})" + ) + + try: + mapped_sl_id = mapped_label.sublabel_mapping[element.label_id] + except KeyError: + raise BadFunctionError( + f"function output shape with unknown sublabel ID ({element.label_id})" + ) + + if not mapped_sl_id: + continue + + if mapped_sl_id in seen_sl_ids: + raise BadFunctionError( + "function output skeleton with multiple elements with same sublabel" + ) + + element.label_id = mapped_sl_id + + seen_sl_ids.add(mapped_sl_id) + + new_elements.append(element) + + if len(new_elements) != mapped_label.expected_num_elements: + # new_elements could only be shorter than expected, + # because the reverse would imply that there are more distinct sublabel IDs + # than are actually defined in the dataset. + assert len(new_elements) < mapped_label.expected_num_elements + + raise BadFunctionError( + f"function output skeleton with fewer elements than expected ({len(new_elements)} vs {mapped_label.expected_num_elements})" + ) + + shape.elements[:] = new_elements + else: + if getattr(shape, "elements", None): + raise BadFunctionError("function output non-skeleton shape with elements") + + shapes[:] = new_shapes + + +@attrs.frozen +class _DetectionFunctionContextImpl(DetectionFunctionContext): + frame_name: str + + +def annotate_task( + client: Client, + task_id: int, + function: DetectionFunction, + *, + pbar: Optional[ProgressReporter] = None, + clear_existing: bool = False, + allow_unmatched_labels: bool = False, +) -> None: + """ + Downloads data for the task with the given ID, applies the given function to it + and uploads the resulting annotations back to the task. + + Only tasks with 2D image (not video) data are supported at the moment. + + client is used to make all requests to the CVAT server. + + Currently, the only type of auto-annotation function supported is the detection function. + A function of this type is applied independently to each image in the task. + The resulting annotations are then combined and modified as follows: + + * The label IDs are replaced with the IDs of the corresponding labels in the task. + * The frame numbers are replaced with the frame number of the image. + * The sources are set to "auto". + + See the documentation for DetectionFunction for more details. + + If the function is found to violate any constraints set in its interface, BadFunctionError + is raised. + + pbar, if supplied, is used to report progress information. + + If clear_existing is true, any annotations already existing in the task are removed. + Otherwise, they are kept, and the new annotations are added to them.
+ + The allow_unmatched_labels parameter controls the behavior in the case when a detection + function declares a label in its spec that has no corresponding label in the task. + If it's set to true, then such labels are allowed, and any annotations returned by the + function that refer to this label are ignored. Otherwise, BadFunctionError is raised. + """ + + if pbar is None: + pbar = NullProgressReporter() + + dataset = TaskDataset(client, task_id) + + assert isinstance(function.spec, DetectionFunctionSpec) + + mapper = _AnnotationMapper( + client.logger, + function.spec.labels, + dataset.labels, + allow_unmatched_labels=allow_unmatched_labels, + ) + + shapes = [] + + with pbar.task(total=len(dataset.samples), unit="samples"): + for sample in pbar.iter(dataset.samples): + frame_shapes = function.detect( + _DetectionFunctionContextImpl(sample.frame_name), sample.media.load_image() + ) + mapper.validate_and_remap(frame_shapes, sample.frame_index) + shapes.extend(frame_shapes) + + client.logger.info("Uploading annotations to task %d", task_id) + + if clear_existing: + client.tasks.api.update_annotations( + task_id, task_annotations_update_request=models.LabeledDataRequest(shapes=shapes) + ) + else: + client.tasks.api.partial_update_annotations( + "create", + task_id, + patched_labeled_data_request=models.PatchedLabeledDataRequest(shapes=shapes), + ) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/__init__.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/cvat-sdk/cvat_sdk/auto_annotation/functions/yolov8n.py b/cvat-sdk/cvat_sdk/auto_annotation/functions/yolov8n.py new file mode 100644 index 000000000000..325f6036a633 --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/functions/yolov8n.py @@ -0,0 +1,36 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +""" +An auto-annotation detection function powered by the YOLOv8n model. +Outputs rectangles. +""" + +from typing import Iterator, List + +import PIL.Image +from ultralytics import YOLO +from ultralytics.engine.results import Results + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.models as models + +_model = YOLO("yolov8n.pt") + +spec = cvataa.DetectionFunctionSpec( + labels=[cvataa.label_spec(name, id) for id, name in _model.names.items()], +) + + +def _yolo_to_cvat(results: List[Results]) -> Iterator[models.LabeledShapeRequest]: + for result in results: + for box, label in zip(result.boxes.xyxy, result.boxes.cls): + yield cvataa.rectangle( + label_id=int(label.item()), + points=[p.item() for p in box], + ) + + +def detect(context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + return list(_yolo_to_cvat(_model.predict(source=image, verbose=False))) diff --git a/cvat-sdk/cvat_sdk/auto_annotation/interface.py b/cvat-sdk/cvat_sdk/auto_annotation/interface.py new file mode 100644 index 000000000000..160d12533d63 --- /dev/null +++ b/cvat-sdk/cvat_sdk/auto_annotation/interface.py @@ -0,0 +1,166 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import abc +from typing import List, Sequence + +import attrs +import PIL.Image +from typing_extensions import Protocol + +import cvat_sdk.models as models + + +@attrs.frozen(kw_only=True) +class DetectionFunctionSpec: + """ + Static information about an auto-annotation detection function. + """ + + labels: Sequence[models.PatchedLabelRequest] + """ + Information about labels that the function supports. 
+ + The members of the sequence must follow the same constraints as if they were being + used to create a CVAT project, and the following additional constraints: + + * The id attribute must be set to a distinct integer. + + * The id attribute of any sublabels must be set to an integer, distinct between all + sublabels of the same parent label. + + * There must not be any attributes (attribute support may be added in a future version). + + It's recommended to use the helper factory functions (label_spec, skeleton_label_spec, + keypoint_spec) to create the label objects, as they are more concise than the model + constructors and help to follow some of the constraints. + """ + + +class DetectionFunctionContext(metaclass=abc.ABCMeta): + """ + Information that is supplied to an auto-annotation detection function. + """ + + @property + @abc.abstractmethod + def frame_name(self) -> str: + """ + The file name of the frame that the current image corresponds to in + the dataset. + """ + ... + + +class DetectionFunction(Protocol): + """ + The interface that an auto-annotation detection function must implement. + + A detection function is supposed to accept an image and return a list of shapes + describing objects in that image. + + Since the same function could be used with multiple datasets, it needs some way + to refer to labels without using dataset-specific label IDs. The way this is + accomplished is that the function declares its own labels via the spec attribute, + and then refers to those labels in the returned annotations. The caller then matches + up the labels from the function's spec with the labels in the actual dataset, and + replaces the label IDs in the returned annotations with IDs of the corresponding + labels in the dataset. + + The matching of labels between the function and the dataset is done by name. + Therefore, a function can be used with a dataset if they have (at least some) labels + that have the same name. + """ + + @property + def spec(self) -> DetectionFunctionSpec: + """Returns the function's spec.""" + ... + + def detect( + self, context: DetectionFunctionContext, image: PIL.Image.Image + ) -> List[models.LabeledShapeRequest]: + """ + Detects objects on the supplied image and returns the results. + + The supplied context will contain information about the current image. + + The returned LabeledShapeRequest objects must follow general constraints + imposed by the data model (such as the number of points in a shape), + as well as the following additional constraints: + + * The id attribute must not be set. + + * The source attribute must not be set. + + * The frame attribute must be set to 0. + + * The label_id attribute must equal one of the label IDs + in the function spec. + + * There must not be any attributes (attribute support may be added in a + future version). + + * The above constraints also apply to each sub-shape (element of a shape), + except that the label_id of a sub-shape must equal one of the sublabel IDs + of the label of its parent shape. + + It's recommended to use the helper factory functions (shape, rectangle, skeleton, + keypoint) to create the shape objects, as they are more concise than the model + constructors and help to follow some of the constraints. + + The function must not retain any references to the returned objects, + so that the caller may freely modify them. + """ + ...
+ + +# spec factories + + +# pylint: disable-next=redefined-builtin +def label_spec(name: str, id: int, **kwargs) -> models.PatchedLabelRequest: + """Helper factory function for PatchedLabelRequest.""" + return models.PatchedLabelRequest(name=name, id=id, **kwargs) + + +# pylint: disable-next=redefined-builtin +def skeleton_label_spec( + name: str, id: int, sublabels: Sequence[models.SublabelRequest], **kwargs +) -> models.PatchedLabelRequest: + """Helper factory function for PatchedLabelRequest with type="skeleton".""" + return models.PatchedLabelRequest(name=name, id=id, type="skeleton", sublabels=sublabels, **kwargs) + + +# pylint: disable-next=redefined-builtin +def keypoint_spec(name: str, id: int, **kwargs) -> models.SublabelRequest: + """Helper factory function for SublabelRequest.""" + return models.SublabelRequest(name=name, id=id, **kwargs) + + +# annotation factories + + +def shape(label_id: int, **kwargs) -> models.LabeledShapeRequest: + """Helper factory function for LabeledShapeRequest with frame=0.""" + return models.LabeledShapeRequest(label_id=label_id, frame=0, **kwargs) + + +def rectangle(label_id: int, points: Sequence[float], **kwargs) -> models.LabeledShapeRequest: + """Helper factory function for LabeledShapeRequest with frame=0 and type="rectangle".""" + return shape(label_id, type="rectangle", points=points, **kwargs) + + +def skeleton( + label_id: int, elements: Sequence[models.SubLabeledShapeRequest], **kwargs +) -> models.LabeledShapeRequest: + """Helper factory function for LabeledShapeRequest with frame=0 and type="skeleton".""" + return shape(label_id, type="skeleton", elements=elements, **kwargs) + + +def keypoint(label_id: int, points: Sequence[float], **kwargs) -> models.SubLabeledShapeRequest: + """Helper factory function for SubLabeledShapeRequest with frame=0 and type="points".""" + return models.SubLabeledShapeRequest( + label_id=label_id, frame=0, type="points", points=points, **kwargs + ) diff --git a/cvat-sdk/cvat_sdk/datasets/common.py b/cvat-sdk/cvat_sdk/datasets/common.py index 2b8269dbd567..c621a2d2ed33 100644 --- a/cvat-sdk/cvat_sdk/datasets/common.py +++ b/cvat-sdk/cvat_sdk/datasets/common.py @@ -50,6 +50,9 @@ class Sample: frame_index: int """Index of the corresponding frame in its task.""" + frame_name: str + """File name of the frame in its task.""" + annotations: FrameAnnotations """Annotations belonging to the frame.""" diff --git a/cvat-sdk/cvat_sdk/datasets/task_dataset.py b/cvat-sdk/cvat_sdk/datasets/task_dataset.py index 586070457934..111528d43715 100644 --- a/cvat-sdk/cvat_sdk/datasets/task_dataset.py +++ b/cvat-sdk/cvat_sdk/datasets/task_dataset.py @@ -126,7 +126,12 @@ def ensure_chunk(chunk_index): # TODO: tracks?
self._samples = [ - Sample(frame_index=k, annotations=v, media=self._TaskMediaElement(self, k)) + Sample( + frame_index=k, + frame_name=data_meta.frames[k].name, + annotations=v, + media=self._TaskMediaElement(self, k), + ) for k, v in self._frame_annotations.items() ] diff --git a/cvat-sdk/gen/templates/openapi-generator/setup.mustache b/cvat-sdk/gen/templates/openapi-generator/setup.mustache index eb89f5d20554..fc6f34144da3 100644 --- a/cvat-sdk/gen/templates/openapi-generator/setup.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/setup.mustache @@ -78,6 +78,7 @@ setup( install_requires=BASE_REQUIREMENTS, extras_require={ "pytorch": ['torch', 'torchvision'], + "ultralytics": ["ultralytics"], }, package_dir={"": "."}, packages=find_packages(include=["cvat_sdk*"]), diff --git a/tests/python/cli/example_function.py b/tests/python/cli/example_function.py new file mode 100644 index 000000000000..4b1b41857825 --- /dev/null +++ b/tests/python/cli/example_function.py @@ -0,0 +1,23 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from typing import List + +import cvat_sdk.auto_annotation as cvataa +import cvat_sdk.models as models +import PIL.Image + +spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 0), + ], +) + + +def detect( + context: cvataa.DetectionFunctionContext, image: PIL.Image.Image +) -> List[models.LabeledShapeRequest]: + return [ + cvataa.rectangle(0, [1, 2, 3, 4]), + ] diff --git a/tests/python/cli/test_cli.py b/tests/python/cli/test_cli.py index 6dbcbb5241fd..fbb6f73fe5fa 100644 --- a/tests/python/cli/test_cli.py +++ b/tests/python/cli/test_cli.py @@ -302,3 +302,29 @@ def test_can_control_organization_context(self): all_task_ids = list(map(int, self.run_cli("ls").split())) assert personal_task_id in all_task_ids assert org_task_id in all_task_ids + + def test_auto_annotate_with_module(self, fxt_new_task: Task): + annotations = fxt_new_task.get_annotations() + assert not annotations.shapes + + self.run_cli( + "auto-annotate", + str(fxt_new_task.id), + f"--function-module={__package__}.example_function", + ) + + annotations = fxt_new_task.get_annotations() + assert annotations.shapes + + def test_auto_annotate_with_file(self, fxt_new_task: Task): + annotations = fxt_new_task.get_annotations() + assert not annotations.shapes + + self.run_cli( + "auto-annotate", + str(fxt_new_task.id), + f"--function-file={Path(__file__).with_name('example_function.py')}", + ) + + annotations = fxt_new_task.get_annotations() + assert annotations.shapes diff --git a/tests/python/sdk/test_auto_annotation.py b/tests/python/sdk/test_auto_annotation.py new file mode 100644 index 000000000000..05814affee7f --- /dev/null +++ b/tests/python/sdk/test_auto_annotation.py @@ -0,0 +1,629 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import io +import sys +from logging import Logger +from pathlib import Path +from types import SimpleNamespace as namespace +from typing import Any, List, Tuple + +import cvat_sdk.auto_annotation as cvataa +import PIL.Image +import pytest +from cvat_sdk import Client, models +from cvat_sdk.core.proxies.tasks import ResourceType + +from shared.utils.helpers import generate_image_file + +from .util import make_pbar + +try: + import numpy as np + from ultralytics.engine.results import Results as UResults +except ModuleNotFoundError: + np = None + UResults = None + + +@pytest.fixture(autouse=True) +def _common_setup( + tmp_path: Path, + fxt_login: Tuple[Client, str], + fxt_logger: 
Tuple[Logger, io.StringIO], +): + logger = fxt_logger[0] + client = fxt_login[0] + client.logger = logger + client.config.cache_dir = tmp_path / "cache" + + api_client = client.api_client + for k in api_client.configuration.logger: + api_client.configuration.logger[k] = logger + + +class TestTaskAutoAnnotation: + @pytest.fixture(autouse=True) + def setup( + self, + tmp_path: Path, + fxt_login: Tuple[Client, str], + ): + self.client = fxt_login[0] + self.images = [ + generate_image_file("1.png", size=(333, 333), color=(0, 0, 0)), + generate_image_file("2.png", size=(333, 333), color=(100, 100, 100)), + ] + + image_dir = tmp_path / "images" + image_dir.mkdir() + + image_paths = [] + for image in self.images: + image_path = image_dir / image.name + image_path.write_bytes(image.getbuffer()) + image_paths.append(image_path) + + self.task = self.client.tasks.create_from_data( + models.TaskWriteRequest( + "Auto-annotation test task", + labels=[ + models.PatchedLabelRequest(name="person"), + models.PatchedLabelRequest(name="car"), + models.PatchedLabelRequest( + name="cat", + type="skeleton", + sublabels=[ + models.SublabelRequest(name="head"), + models.SublabelRequest(name="tail"), + ], + ), + ], + ), + resource_type=ResourceType.LOCAL, + resources=image_paths, + ) + + task_labels = self.task.get_labels() + self.task_labels_by_id = {label.id: label for label in task_labels} + self.cat_sublabels_by_id = { + sl.id: sl + for sl in next(label for label in task_labels if label.name == "cat").sublabels + } + + # The initial annotation is just to check that it gets erased after auto-annotation + self.task.update_annotations( + models.PatchedLabeledDataRequest( + shapes=[ + models.LabeledShapeRequest( + frame=0, + label_id=next(iter(self.task_labels_by_id)), + type="rectangle", + points=[1.0, 2.0, 3.0, 4.0], + ), + ], + ) + ) + + def test_detection_rectangle(self): + spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 123), + cvataa.label_spec("bicycle (should be ignored)", 456), + ], + ) + + def detect( + context: cvataa.DetectionFunctionContext, image: PIL.Image.Image + ) -> List[models.LabeledShapeRequest]: + assert context.frame_name in {"1.png", "2.png"} + assert image.width == image.height == 333 + return [ + cvataa.rectangle( + 123, # car + # produce different coordinates for different images + [*image.getpixel((0, 0)), 300 + int(context.frame_name[0])], + ), + cvataa.shape( + 456, # ignored + type="points", + points=[1, 1], + ), + ] + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + clear_existing=True, + allow_unmatched_labels=True, + ) + + annotations = self.task.get_annotations() + + shapes = sorted(annotations.shapes, key=lambda shape: shape.frame) + + assert len(shapes) == 2 + + for i, shape in enumerate(shapes): + assert shape.frame == i + assert shape.type.value == "rectangle" + assert self.task_labels_by_id[shape.label_id].name == "car" + assert shape.points[3] in {301, 302} + + assert shapes[0].points[0] != shapes[1].points[0] + assert shapes[0].points[3] != shapes[1].points[3] + + def test_detection_skeleton(self): + spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.skeleton_label_spec( + "cat", + 123, + [ + cvataa.keypoint_spec("head", 10), + cvataa.keypoint_spec("torso (should be ignored)", 20), + cvataa.keypoint_spec("tail", 30), + ], + ), + ], + ) + + def detect(context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + assert image.width == image.height == 333 + return [ + cvataa.skeleton( + 
123, # cat + [ + # ignored + cvataa.keypoint(20, [20, 20]), + # tail + cvataa.keypoint(30, [30, 30]), + # head + cvataa.keypoint(10, [10, 10]), + ], + ), + ] + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + clear_existing=True, + allow_unmatched_labels=True, + ) + + annotations = self.task.get_annotations() + + shapes = sorted(annotations.shapes, key=lambda shape: shape.frame) + + assert len(shapes) == 2 + + for i, shape in enumerate(shapes): + assert shape.frame == i + assert shape.type.value == "skeleton" + assert self.task_labels_by_id[shape.label_id].name == "cat" + assert len(shape.elements) == 2 + + elements = sorted( + shape.elements, key=lambda s: self.cat_sublabels_by_id[s.label_id].name + ) + + for element in elements: + assert element.frame == i + assert element.type.value == "points" + + assert self.cat_sublabels_by_id[elements[0].label_id].name == "head" + assert elements[0].points == [10, 10] + assert self.cat_sublabels_by_id[elements[1].label_id].name == "tail" + assert elements[1].points == [30, 30] + + def test_progress_reporting(self): + spec = cvataa.DetectionFunctionSpec(labels=[]) + + def detect(context, image): + return [] + + file = io.StringIO() + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + pbar=make_pbar(file), + ) + + assert "100%" in file.getvalue() + + def test_detection_without_clearing(self): + spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 123), + ], + ) + + def detect(context, image: PIL.Image.Image) -> List[models.LabeledShapeRequest]: + return [ + cvataa.rectangle( + 123, # car + [5, 6, 7, 8], + rotation=10, + ), + ] + + cvataa.annotate_task( + self.client, + self.task.id, + namespace(spec=spec, detect=detect), + clear_existing=False, + ) + + annotations = self.task.get_annotations() + + shapes = sorted(annotations.shapes, key=lambda shape: (shape.frame, shape.rotation)) + + # original annotation + assert shapes[0].points == [1, 2, 3, 4] + assert shapes[0].rotation == 0 + + # new annotations + for i in (1, 2): + assert shapes[i].points == [5, 6, 7, 8] + assert shapes[i].rotation == 10 + + def _test_bad_function_spec(self, spec: cvataa.DetectionFunctionSpec, exc_match: str) -> None: + def detect(context, image): + assert False + + with pytest.raises(cvataa.BadFunctionError, match=exc_match): + cvataa.annotate_task(self.client, self.task.id, namespace(spec=spec, detect=detect)) + + def test_attributes(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec( + "car", + 123, + attributes=[ + models.AttributeRequest( + "age", + mutable=False, + input_type="number", + values=["0", "100", "1"], + default_value="0", + ) + ], + ), + ], + ), + "currently not supported", + ) + + def test_label_not_in_dataset(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[cvataa.label_spec("dog", 123)], + ), + "not in dataset", + ) + + def test_label_without_id(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + models.PatchedLabelRequest( + name="car", + ), + ], + ), + "label .+ has no ID", + ) + + def test_duplicate_label_id(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 123), + cvataa.label_spec("bicycle", 123), + ], + ), + "same ID as another label", + ) + + def test_non_skeleton_sublabels(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + 
cvataa.label_spec( + "car", + 123, + sublabels=[models.SublabelRequest("wheel", id=1)], + ), + ], + ), + "should be 'skeleton'", + ) + + def test_sublabel_without_id(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + cvataa.skeleton_label_spec( + "car", + 123, + [models.SublabelRequest("wheel")], + ), + ], + ), + "sublabel .+ of label .+ has no ID", + ) + + def test_duplicate_sublabel_id(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + cvataa.skeleton_label_spec( + "cat", + 123, + [ + cvataa.keypoint_spec("head", 1), + cvataa.keypoint_spec("tail", 1), + ], + ), + ], + ), + "same ID as another sublabel", + ) + + def test_sublabel_not_in_dataset(self): + self._test_bad_function_spec( + cvataa.DetectionFunctionSpec( + labels=[ + cvataa.skeleton_label_spec("cat", 123, [cvataa.keypoint_spec("nose", 1)]), + ], + ), + "not in dataset", + ) + + def _test_bad_function_detect(self, detect, exc_match: str) -> None: + spec = cvataa.DetectionFunctionSpec( + labels=[ + cvataa.label_spec("car", 123), + cvataa.skeleton_label_spec( + "cat", + 456, + [ + cvataa.keypoint_spec("head", 12), + cvataa.keypoint_spec("tail", 34), + ], + ), + ], + ) + + with pytest.raises(cvataa.BadFunctionError, match=exc_match): + cvataa.annotate_task(self.client, self.task.id, namespace(spec=spec, detect=detect)) + + def test_preset_shape_id(self): + self._test_bad_function_detect( + lambda context, image: [ + models.LabeledShapeRequest( + type="rectangle", frame=0, label_id=123, id=1111, points=[1, 2, 3, 4] + ), + ], + "shape with preset id", + ) + + def test_preset_shape_source(self): + self._test_bad_function_detect( + lambda context, image: [ + models.LabeledShapeRequest( + type="rectangle", frame=0, label_id=123, source="manual", points=[1, 2, 3, 4] + ), + ], + "shape with preset source", + ) + + def test_bad_shape_frame_number(self): + self._test_bad_function_detect( + lambda context, image: [ + models.LabeledShapeRequest( + type="rectangle", + frame=1, + label_id=123, + points=[1, 2, 3, 4], + ), + ], + "unexpected frame number", + ) + + def test_unknown_label_id(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.rectangle(111, [1, 2, 3, 4]), + ], + "unknown label ID", + ) + + def test_shape_with_attributes(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.rectangle( + 123, + [1, 2, 3, 4], + attributes=[ + models.AttributeValRequest(spec_id=1, value="asdf"), + ], + ), + ], + "shape with attributes", + ) + + def test_preset_element_id(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton( + 456, + [ + models.SubLabeledShapeRequest( + type="points", frame=0, label_id=12, id=1111, points=[1, 2] + ), + ], + ), + ], + "element with preset id", + ) + + def test_preset_element_source(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton( + 456, + [ + models.SubLabeledShapeRequest( + type="points", frame=0, label_id=12, source="manual", points=[1, 2] + ), + ], + ), + ], + "element with preset source", + ) + + def test_bad_element_frame_number(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton( + 456, + [ + models.SubLabeledShapeRequest( + type="points", frame=1, label_id=12, points=[1, 2] + ), + ], + ), + ], + "element with unexpected frame number", + ) + + def test_non_points_element(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton( + 456, + [ + models.SubLabeledShapeRequest( + 
type="rectangle", frame=0, label_id=12, points=[1, 2, 3, 4] + ), + ], + ), + ], + "element type other than 'points'", + ) + + def test_unknown_sublabel_id(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton(456, [cvataa.keypoint(56, [1, 2])]), + ], + "unknown sublabel ID", + ) + + def test_multiple_elements_with_same_sublabel(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton( + 456, + [ + cvataa.keypoint(12, [1, 2]), + cvataa.keypoint(12, [3, 4]), + ], + ), + ], + "multiple elements with same sublabel", + ) + + def test_not_enough_elements(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.skeleton(456, [cvataa.keypoint(12, [1, 2])]), + ], + "with fewer elements than expected", + ) + + def test_non_skeleton_with_elements(self): + self._test_bad_function_detect( + lambda context, image: [ + cvataa.shape( + 456, + type="rectangle", + elements=[cvataa.keypoint(12, [1, 2])], + ), + ], + "non-skeleton shape with elements", + ) + + +class FakeYolo: + def __init__(self, *args, **kwargs) -> None: + pass + + names = {42: "person"} + + def predict(self, source: Any, **kwargs) -> "List[UResults]": + return [ + UResults( + orig_img=np.zeros([100, 100, 3]), + path=None, + names=self.names, + boxes=np.array([[1, 2, 3, 4, 0.9, 42]]), + ) + ] + + +@pytest.mark.skipif(UResults is None, reason="Ultralytics is not installed") +class TestAutoAnnotationFunctions: + @pytest.fixture(autouse=True) + def setup( + self, + tmp_path: Path, + fxt_login: Tuple[Client, str], + ): + self.client = fxt_login[0] + self.image = generate_image_file("1.png", size=(100, 100)) + + image_dir = tmp_path / "images" + image_dir.mkdir() + + image_path = image_dir / self.image.name + image_path.write_bytes(self.image.getbuffer()) + + self.task = self.client.tasks.create_from_data( + models.TaskWriteRequest( + "Auto-annotation test task", + labels=[ + models.PatchedLabelRequest(name="person"), + ], + ), + resources=[image_path], + ) + + task_labels = self.task.get_labels() + self.task_labels_by_id = {label.id: label for label in task_labels} + + def test_yolov8n(self, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr("ultralytics.YOLO", FakeYolo) + + import cvat_sdk.auto_annotation.functions.yolov8n as yolov8n + + try: + cvataa.annotate_task(self.client, self.task.id, yolov8n) + + annotations = self.task.get_annotations() + + assert len(annotations.shapes) == 1 + assert self.task_labels_by_id[annotations.shapes[0].label_id].name == "person" + assert annotations.shapes[0].type.value == "rectangle" + assert annotations.shapes[0].points == [1, 2, 3, 4] + + finally: + del sys.modules[yolov8n.__name__] diff --git a/tests/python/sdk/test_datasets.py b/tests/python/sdk/test_datasets.py index 67204e4c26c9..35b2339ec67e 100644 --- a/tests/python/sdk/test_datasets.py +++ b/tests/python/sdk/test_datasets.py @@ -101,6 +101,7 @@ def test_basic(self): for index, sample in enumerate(dataset.samples): assert sample.frame_index == index + assert sample.frame_name == self.images[index].name actual_image = sample.media.load_image() expected_image = PIL.Image.open(self.images[index])