From 86160ee37cb8684de7550edcddddd70ffc171063 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Thu, 8 Dec 2022 18:28:55 +0300 Subject: [PATCH] SDK: Add an adapter layer that presents a CVAT task as a torchvision dataset (#5417) --- .github/workflows/full.yml | 2 +- .github/workflows/main.yml | 2 +- .github/workflows/schedule.yml | 2 +- CHANGELOG.md | 2 + cvat-sdk/cvat_sdk/pytorch/__init__.py | 359 ++++++++++++++++++ .../openapi-generator/setup.mustache | 3 + tests/python/sdk/test_pytorch.py | 207 ++++++++++ tests/python/shared/utils/helpers.py | 6 +- 8 files changed, 577 insertions(+), 6 deletions(-) create mode 100644 cvat-sdk/cvat_sdk/pytorch/__init__.py create mode 100644 tests/python/sdk/test_pytorch.py diff --git a/.github/workflows/full.yml b/.github/workflows/full.yml index 7d03cad85235..7772637f6c4c 100644 --- a/.github/workflows/full.yml +++ b/.github/workflows/full.yml @@ -196,7 +196,7 @@ jobs: - name: Running REST API and SDK tests run: | - pip3 install --user /tmp/cvat_sdk/ + pip3 install --user '/tmp/cvat_sdk/[pytorch]' pip3 install --user cvat-cli/ pip3 install --user -r tests/python/requirements.txt pytest tests/python -s -v diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6e4c10916c27..a647b142a983 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -164,7 +164,7 @@ jobs: - name: Running REST API and SDK tests run: | - pip3 install --user /tmp/cvat_sdk/ + pip3 install --user '/tmp/cvat_sdk/[pytorch]' pip3 install --user cvat-cli/ pip3 install --user -r tests/python/requirements.txt pytest tests/python/ -s -v diff --git a/.github/workflows/schedule.yml b/.github/workflows/schedule.yml index 3d617bad2c8f..61bfc638cba2 100644 --- a/.github/workflows/schedule.yml +++ b/.github/workflows/schedule.yml @@ -235,7 +235,7 @@ jobs: gen/generate.sh cd .. 
- pip3 install --user cvat-sdk/ + pip3 install --user 'cvat-sdk/[pytorch]' pip3 install --user cvat-cli/ pip3 install --user -r tests/python/requirements.txt pytest tests/python/ diff --git a/CHANGELOG.md b/CHANGELOG.md index bc95ae9eb698..dc50413c5773 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ from online detectors & interactors) ( - Authentication with social accounts google & github (, , ) - REST API tests to export job datasets & annotations and validate their structure () - Propagation backward on UI () +- A PyTorch dataset adapter layer in the SDK + () ### Changed - `api/docs`, `api/swagger`, `api/schema`, `server/about` endpoints now allow unauthorized access (, ) diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py new file mode 100644 index 000000000000..55b88186e7a5 --- /dev/null +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -0,0 +1,359 @@ +import base64 +import collections +import json +import os +import shutil +import types +import zipfile +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import ( + Callable, + Dict, + FrozenSet, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, +) + +import appdirs +import attrs +import attrs.validators +import PIL.Image +import torchvision.datasets +from typing_extensions import TypedDict + +import cvat_sdk.core +import cvat_sdk.core.exceptions +from cvat_sdk.api_client.model_utils import to_json +from cvat_sdk.core.utils import atomic_writer +from cvat_sdk.models import DataMetaRead, LabeledData, LabeledImage, LabeledShape, TaskRead + +_ModelType = TypeVar("_ModelType") + +_CACHE_DIR = Path(appdirs.user_cache_dir("cvat-sdk", "CVAT.ai")) +_NUM_DOWNLOAD_THREADS = 4 + + +class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException): + pass + + +@attrs.frozen +class FrameAnnotations: + """ + Contains annotations that pertain to a single frame. 
+ """ + + tags: List[LabeledImage] = attrs.Factory(list) + shapes: List[LabeledShape] = attrs.Factory(list) + + +@attrs.frozen +class Target: + """ + Non-image data for a dataset sample. + """ + + annotations: FrameAnnotations + """Annotations for the frame corresponding to the sample.""" + + label_id_to_index: Mapping[int, int] + """ + A mapping from label_id values in `LabeledImage` and `LabeledShape` objects + to an index in the range [0, num_labels), where num_labels is the number of labels + defined in the task. This mapping is consistent across all samples for a given task. + """ + + +class TaskVisionDataset(torchvision.datasets.VisionDataset): + """ + Represents a task on a CVAT server as a PyTorch Dataset. + + This dataset contains one sample for each frame in the task, in the same + order as the frames are in the task. Deleted frames are omitted. + Before transforms are applied, each sample is a tuple of + (image, target), where: + + * image is a `PIL.Image.Image` object for the corresponding frame. + * target is a `Target` object containing annotations for the frame. + + This class caches all data and annotations for the task on the local file system + during construction. If the task is updated on the server, the cache is updated. + + Limitations: + + * Only tasks with image (not video) data are supported at the moment. + * Track annotations are currently not accessible. + """ + + def __init__( + self, + client: cvat_sdk.core.Client, + task_id: int, + *, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + """ + Creates a dataset corresponding to the task with ID `task_id` on the + server that `client` is connected to. + + `transforms`, `transform` and `target_transform` are optional transformation + functions; see the documentation for `torchvision.datasets.VisionDataset` for + more information. 
+ """ + + self._logger = client.logger + + self._logger.info(f"Fetching task {task_id}...") + self._task = client.tasks.retrieve(task_id) + + if not self._task.size or not self._task.data_chunk_size: + raise UnsupportedDatasetError("The task has no data") + + if self._task.data_original_chunk_type != "imageset": + raise UnsupportedDatasetError( + f"{self.__class__.__name__} only supports tasks with image chunks;" + f" current chunk type is {self._task.data_original_chunk_type!r}" + ) + + # Base64-encode the name to avoid FS-unsafe characters (like slashes) + server_dir_name = ( + base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() + ) + server_dir = _CACHE_DIR / f"servers/{server_dir_name}" + + self._task_dir = server_dir / f"tasks/{self._task.id}" + self._initialize_task_dir() + + super().__init__( + os.fspath(self._task_dir), + transforms=transforms, + transform=transform, + target_transform=target_transform, + ) + + data_meta = self._ensure_model( + "data_meta.json", DataMetaRead, self._task.get_meta, "data metadata" + ) + self._active_frame_indexes = sorted( + set(range(self._task.size)) - set(data_meta.deleted_frames) + ) + + self._logger.info("Downloading chunks...") + + self._chunk_dir = self._task_dir / "chunks" + self._chunk_dir.mkdir(exist_ok=True, parents=True) + + needed_chunks = { + index // self._task.data_chunk_size for index in self._active_frame_indexes + } + + with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool: + for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)): + # just need to loop through all results so that any exceptions are propagated + pass + + self._logger.info("All chunks downloaded") + + self._label_id_to_index = types.MappingProxyType( + { + label["id"]: label_index + for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id)) + } + ) + + annotations = self._ensure_model( + "annotations.json", LabeledData, self._task.get_annotations, "annotations" + ) + + 
self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict( + FrameAnnotations + ) + + for tag in annotations.tags: + self._frame_annotations[tag.frame].tags.append(tag) + + for shape in annotations.shapes: + self._frame_annotations[shape.frame].shapes.append(shape) + + # TODO: tracks? + + def _initialize_task_dir(self) -> None: + task_json_path = self._task_dir / "task.json" + + try: + with open(task_json_path, "rb") as task_json_file: + saved_task = TaskRead._new_from_openapi_data(**json.load(task_json_file)) + except Exception: + self._logger.info("Task is not yet cached or the cache is corrupted") + + # If the cache was corrupted, the directory might already be there; clear it. + if self._task_dir.exists(): + shutil.rmtree(self._task_dir) + else: + if saved_task.updated_date < self._task.updated_date: + self._logger.info( + "Task has been updated on the server since it was cached; purging the cache" + ) + shutil.rmtree(self._task_dir) + + self._task_dir.mkdir(exist_ok=True, parents=True) + + with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file: + json.dump(to_json(self._task._model), task_json_file, indent=4) + print(file=task_json_file) # add final newline + + def _ensure_chunk(self, chunk_index: int) -> None: + chunk_path = self._chunk_dir / f"{chunk_index}.zip" + if chunk_path.exists(): + return # already downloaded previously + + self._logger.info(f"Downloading chunk #{chunk_index}...") + + with atomic_writer(chunk_path, "wb") as chunk_file: + self._task.download_chunk(chunk_index, chunk_file, quality="original") + + def _ensure_model( + self, + filename: str, + model_type: Type[_ModelType], + download: Callable[[], _ModelType], + model_description: str, + ) -> _ModelType: + path = self._task_dir / filename + + try: + with open(path, "rb") as f: + model = model_type._new_from_openapi_data(**json.load(f)) + self._logger.info(f"Loaded {model_description} from cache") + return model + except FileNotFoundError: + pass 
+ except Exception: + self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True) + + self._logger.info(f"Downloading {model_description}...") + model = download() + self._logger.info(f"Downloaded {model_description}") + + with atomic_writer(path, "w", encoding="UTF-8") as f: + json.dump(to_json(model), f, indent=4) + print(file=f) # add final newline + + return model + + def __getitem__(self, sample_index: int): + """ + Returns the sample with index `sample_index`. + + `sample_index` must satisfy the condition `0 <= sample_index < len(self)`. + """ + + frame_index = self._active_frame_indexes[sample_index] + chunk_index = frame_index // self._task.data_chunk_size + member_index = frame_index % self._task.data_chunk_size + + with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip: + with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member: + sample_image = PIL.Image.open(chunk_member) + sample_image.load() + + sample_target = Target( + annotations=self._frame_annotations[frame_index], + label_id_to_index=self._label_id_to_index, + ) + + if self.transforms: + sample_image, sample_target = self.transforms(sample_image, sample_target) + return sample_image, sample_target + + def __len__(self) -> int: + """Returns the number of samples in the dataset.""" + return len(self._active_frame_indexes) + + +@attrs.frozen +class ExtractSingleLabelIndex: + """ + A target transform that takes a `Target` object and produces a single label index + based on the tag in that object. + + This makes the dataset samples compatible with the image classification networks + in torchvision. + + If the annotations contain no tags, or multiple tags, raises a `ValueError`. 
+ """ + + def __call__(self, target: Target) -> int: + tags = target.annotations.tags + if not tags: + raise ValueError("sample has no tags") + + if len(tags) > 1: + raise ValueError("sample has multiple tags") + + return target.label_id_to_index[tags[0].label_id] + + +class LabeledBoxes(TypedDict): + boxes: Sequence[Tuple[float, float, float, float]] + labels: Sequence[int] + + +_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"]) + + +@attrs.frozen +class ExtractBoundingBoxes: + """ + A target transform that takes a `Target` object and returns a dictionary compatible + with the object detection networks in torchvision. + + The dictionary contains the following entries: + + "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape + in the annotations. + "labels": a sequence of corresponding label indices. + + Limitations: + + * Only the following shape types are supported: rectangle, polygon, polyline, + points, ellipse. + * Rotated shapes are not supported. + + Unsupported shapes will cause an `UnsupportedDatasetError` exception to be + raised unless they are filtered out by `include_shape_types`. 
+ """ + + include_shape_types: FrozenSet[str] = attrs.field( + converter=frozenset, + validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)), + kw_only=True, + ) + """Shapes whose type is not in this set will be ignored.""" + + def __call__(self, target: Target) -> LabeledBoxes: + boxes = [] + labels = [] + + for shape in target.annotations.shapes: + if shape.type.value not in self.include_shape_types: + continue + + if shape.rotation != 0: + raise UnsupportedDatasetError("Rotated shapes are not supported") + + x_coords = shape.points[0::2] + y_coords = shape.points[1::2] + + boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords))) + labels.append(target.label_id_to_index[shape.label_id]) + + return LabeledBoxes(boxes=boxes, labels=labels) diff --git a/cvat-sdk/gen/templates/openapi-generator/setup.mustache b/cvat-sdk/gen/templates/openapi-generator/setup.mustache index 40cd4ac45a81..13c2a9535966 100644 --- a/cvat-sdk/gen/templates/openapi-generator/setup.mustache +++ b/cvat-sdk/gen/templates/openapi-generator/setup.mustache @@ -76,6 +76,9 @@ setup( ], python_requires="{{{generatorLanguageVersion}}}", install_requires=BASE_REQUIREMENTS, + extras_require={ + "pytorch": ['appdirs', 'torch', 'torchvision'], + }, package_dir={"": "."}, packages=find_packages(include=["cvat_sdk*"]), include_package_data=True, diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py new file mode 100644 index 000000000000..69d329b72688 --- /dev/null +++ b/tests/python/sdk/test_pytorch.py @@ -0,0 +1,207 @@ +# Copyright (C) 2022 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import io +import os +from logging import Logger +from pathlib import Path +from typing import Tuple + +import pytest +from cvat_sdk import Client, models +from cvat_sdk.core.proxies.tasks import ResourceType + +try: + import cvat_sdk.pytorch as cvatpt + import PIL.Image + import torch + import torchvision.transforms + import 
torchvision.transforms.functional as TF + from torch.utils.data import DataLoader +except ImportError: + cvatpt = None + +from shared.utils.helpers import generate_image_files + + +@pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") +class TestTaskVisionDataset: + @pytest.fixture(autouse=True) + def setup( + self, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + fxt_login: Tuple[Client, str], + fxt_logger: Tuple[Logger, io.StringIO], + fxt_stdout: io.StringIO, + ): + self.tmp_path = tmp_path + logger, self.logger_stream = fxt_logger + self.stdout = fxt_stdout + self.client, self.user = fxt_login + self.client.logger = logger + + api_client = self.client.api_client + for k in api_client.configuration.logger: + api_client.configuration.logger[k] = logger + + monkeypatch.setattr(cvatpt, "_CACHE_DIR", self.tmp_path / "cache") + + self._create_task() + + yield + + def _create_task(self): + self.images = generate_image_files(10) + + image_dir = self.tmp_path / "images" + image_dir.mkdir() + + image_paths = [] + for image in self.images: + image_path = image_dir / image.name + image_path.write_bytes(image.getbuffer()) + image_paths.append(image_path) + + self.task = self.client.tasks.create_from_data( + models.TaskWriteRequest( + "PyTorch integration test task", + labels=[ + models.PatchedLabelRequest(name="person"), + models.PatchedLabelRequest(name="car"), + ], + ), + ResourceType.LOCAL, + list(map(os.fspath, image_paths)), + data_params={"chunk_size": 3}, + ) + + self.label_ids = sorted(l.id for l in self.task.labels) + + self.task.update_annotations( + models.PatchedLabeledDataRequest( + tags=[ + models.LabeledImageRequest(frame=5, label_id=self.label_ids[0]), + models.LabeledImageRequest(frame=6, label_id=self.label_ids[1]), + models.LabeledImageRequest(frame=8, label_id=self.label_ids[0]), + models.LabeledImageRequest(frame=8, label_id=self.label_ids[1]), + ], + shapes=[ + models.LabeledShapeRequest( + frame=6, + 
label_id=self.label_ids[1], + type=models.ShapeType("rectangle"), + points=[1.0, 2.0, 3.0, 4.0], + ), + models.LabeledShapeRequest( + frame=7, + label_id=self.label_ids[0], + type=models.ShapeType("points"), + points=[1.1, 2.1, 3.1, 4.1], + ), + ], + ) + ) + + def test_basic(self): + dataset = cvatpt.TaskVisionDataset(self.client, self.task.id) + + assert len(dataset) == self.task.size + + for index, (sample_image, sample_target) in enumerate(dataset): + sample_image_tensor = TF.pil_to_tensor(sample_image) + reference_tensor = TF.pil_to_tensor(PIL.Image.open(self.images[index])) + assert torch.equal(sample_image_tensor, reference_tensor) + + for index, label_id in enumerate(self.label_ids): + assert sample_target.label_id_to_index[label_id] == index + + assert not dataset[0][1].annotations.tags + assert not dataset[0][1].annotations.shapes + + assert len(dataset[5][1].annotations.tags) == 1 + assert dataset[5][1].annotations.tags[0].label_id == self.label_ids[0] + assert not dataset[5][1].annotations.shapes + + assert len(dataset[6][1].annotations.tags) == 1 + assert dataset[6][1].annotations.tags[0].label_id == self.label_ids[1] + assert len(dataset[6][1].annotations.shapes) == 1 + assert dataset[6][1].annotations.shapes[0].type.value == "rectangle" + assert dataset[6][1].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0] + + assert not dataset[7][1].annotations.tags + assert len(dataset[7][1].annotations.shapes) == 1 + assert dataset[7][1].annotations.shapes[0].type.value == "points" + assert dataset[7][1].annotations.shapes[0].points == [1.1, 2.1, 3.1, 4.1] + + def test_deleted_frame(self): + self.task.remove_frames_by_ids([1]) + + dataset = cvatpt.TaskVisionDataset(self.client, self.task.id) + + assert len(dataset) == self.task.size - 1 + + # sample #0 is still frame #0 + assert torch.equal( + TF.pil_to_tensor(dataset[0][0]), TF.pil_to_tensor(PIL.Image.open(self.images[0])) + ) + + # sample #1 is now frame #2 + assert torch.equal( + 
TF.pil_to_tensor(dataset[1][0]), TF.pil_to_tensor(PIL.Image.open(self.images[2])) + ) + + # sample #4 is now frame #5 + assert len(dataset[4][1].annotations.tags) == 1 + assert dataset[4][1].annotations.tags[0].label_id == self.label_ids[0] + assert not dataset[4][1].annotations.shapes + + def test_extract_single_label_index(self): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + transform=torchvision.transforms.PILToTensor(), + target_transform=cvatpt.ExtractSingleLabelIndex(), + ) + + assert dataset[5][1] == 0 + assert dataset[6][1] == 1 + + with pytest.raises(ValueError): + # no tags + _ = dataset[7] + + with pytest.raises(ValueError): + # multiple tags + _ = dataset[8] + + # make sure the samples can be batched with the default collater + loader = DataLoader(dataset, batch_size=2, sampler=[5, 6]) + + batch = next(iter(loader)) + assert torch.equal(batch[0][0], TF.pil_to_tensor(PIL.Image.open(self.images[5]))) + assert torch.equal(batch[0][1], TF.pil_to_tensor(PIL.Image.open(self.images[6]))) + assert torch.equal(batch[1], torch.tensor([0, 1])) + + def test_extract_bounding_boxes(self): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + transform=torchvision.transforms.PILToTensor(), + target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}), + ) + + assert dataset[0][1] == {"boxes": [], "labels": []} + assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]} + assert dataset[7][1] == {"boxes": [], "labels": []} # points are filtered out + + def test_transforms(self): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + transforms=lambda x, y: (y, x), + ) + + assert isinstance(dataset[0][0], cvatpt.Target) + assert isinstance(dataset[0][1], PIL.Image.Image) diff --git a/tests/python/shared/utils/helpers.py b/tests/python/shared/utils/helpers.py index 872c7085237f..a8a7120c78bd 100644 --- a/tests/python/shared/utils/helpers.py +++ 
b/tests/python/shared/utils/helpers.py @@ -8,9 +8,9 @@ from PIL import Image -def generate_image_file(filename="image.png", size=(50, 50)): +def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)): f = BytesIO() - image = Image.new("RGB", size=size) + image = Image.new("RGB", size=size, color=color) image.save(f, "jpeg") f.name = filename f.seek(0) @@ -21,7 +21,7 @@ def generate_image_file(filename="image.png", size=(50, 50)): def generate_image_files(count) -> List[BytesIO]: images = [] for i in range(count): - image = generate_image_file(f"{i}.jpeg") + image = generate_image_file(f"{i}.jpeg", color=(i, i, i)) images.append(image) return images