diff --git a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py index 767fb07fe962..32034d496584 100644 --- a/cvat/apps/dataset_manager/tests/test_rest_api_formats.py +++ b/cvat/apps/dataset_manager/tests/test_rest_api_formats.py @@ -26,7 +26,6 @@ from attr import define, field from datumaro.components.dataset import Dataset -from datumaro.util.test_utils import compare_datasets, TestDir from django.contrib.auth.models import Group, User from PIL import Image from rest_framework import status @@ -34,6 +33,7 @@ import cvat.apps.dataset_manager as dm from cvat.apps.dataset_manager.bindings import CvatTaskOrJobDataExtractor, TaskData from cvat.apps.dataset_manager.task import TaskAnnotation +from cvat.apps.dataset_manager.tests.utils import compare_datasets, TestDir from cvat.apps.dataset_manager.util import get_export_cache_lock from cvat.apps.dataset_manager.views import clear_export_cache, export, parse_export_file_path from cvat.apps.engine.models import Task @@ -1034,7 +1034,7 @@ def test_api_v2_tasks_annotations_dump_and_upload_many_jobs_with_datumaro(self): # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) - compare_datasets(self, data_from_task_before_upload, data_from_task_after_upload) + compare_datasets(data_from_task_before_upload, data_from_task_after_upload) def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self): test_name = self._testMethodName @@ -1110,7 +1110,7 @@ def test_api_v2_tasks_annotations_dump_and_upload_with_datumaro(self): # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) - compare_datasets(self, data_from_task_before_upload, data_from_task_after_upload) + compare_datasets(data_from_task_before_upload, data_from_task_after_upload) def test_api_v2_check_duplicated_polygon_points(self): test_name = self._testMethodName @@ -1176,7 +1176,7 @@ def test_api_v2_check_widerface_with_all_attributes(self): # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) - compare_datasets(self, data_from_task_before_upload, data_from_task_after_upload) + compare_datasets(data_from_task_before_upload, data_from_task_after_upload) def test_api_v2_check_mot_with_shapes_only(self): test_name = self._testMethodName @@ -1212,7 +1212,7 @@ def test_api_v2_check_mot_with_shapes_only(self): # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) - compare_datasets(self, data_from_task_before_upload, data_from_task_after_upload) + compare_datasets(data_from_task_before_upload, data_from_task_after_upload) def test_api_v2_check_attribute_import_in_tracks(self): test_name = self._testMethodName @@ -1249,7 +1249,7 @@ def test_api_v2_check_attribute_import_in_tracks(self): # equals annotations data_from_task_after_upload = self._get_data_from_task(task_id, include_images) - compare_datasets(self, data_from_task_before_upload, data_from_task_after_upload) + compare_datasets(data_from_task_before_upload, data_from_task_after_upload) class ExportBehaviorTest(_DbTestBase): @define diff --git a/cvat/apps/dataset_manager/tests/utils.py b/cvat/apps/dataset_manager/tests/utils.py new file mode 100644 index 000000000000..a02484878b87 --- /dev/null +++ b/cvat/apps/dataset_manager/tests/utils.py @@ -0,0 +1,85 @@ +# Copyright (C) 2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import os +import tempfile +import unittest +from typing import Optional + +from datumaro import IDataset +from datumaro.components.operations import ExactComparator +from datumaro.util.os_util import rmfile, rmtree + +from cvat.apps.dataset_manager.util import current_function_name + + +class FileRemover: + def __init__(self, path, is_dir=False): + self.path = path + self.is_dir = is_dir + + def __enter__(self): + return self.path + + def __exit__(self, exc_type=None, exc_value=None, traceback=None): + if self.is_dir: + try: + rmtree(self.path) + except unittest.SkipTest: + # Suppress skip test errors from git.util.rmtree + if not exc_type: + raise + else: + rmfile(self.path) + + +class TestDir(FileRemover): + """ + Creates a temporary directory for a test. Uses the name of + the test function to name the directory. + + Usage: + + .. code-block:: + + with TestDir() as test_dir: + ... + """ + + def __init__(self, path: Optional[str] = None, frame_id: int = 2): + if not path: + prefix = f"temp_{current_function_name(frame_id)}-" + else: + prefix = None + self._prefix = prefix + + super().__init__(path, is_dir=True) + + def __enter__(self) -> str: + """ + Creates a test directory. + + Returns: path to the directory + """ + + path = self.path + + if path is None: + path = tempfile.mkdtemp(dir=os.getcwd(), prefix=self._prefix) + self.path = path + else: + os.makedirs(path, exist_ok=False) + + return path + + +def compare_datasets(expected: IDataset, actual: IDataset): + comparator = ExactComparator() + _, unmatched, expected_extra, actual_extra, errors = comparator.compare_datasets( + expected, actual + ) + assert not unmatched, f"Datasets have unmatched items: {unmatched}" + assert not actual_extra, f"Actual has following extra items: {actual_extra}" + assert not expected_extra, f"Actual has following extra items: {expected_extra}" + assert not errors, f"There were following errors while comparing datasets: {errors}" diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index 47758be11d15..79cb2f516db3 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -37,7 +37,8 @@ from rest_framework import status from rest_framework.test import APIClient -from datumaro.util.test_utils import current_function_name, TestDir +from cvat.apps.dataset_manager.tests.utils import TestDir +from cvat.apps.dataset_manager.util import current_function_name from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, Project, Segment, StageChoice, StatusChoice, Task, Label, StorageMethodChoice, StorageChoice, DimensionType, SortingMethod) diff --git a/cvat/apps/engine/tests/test_rest_api_3D.py b/cvat/apps/engine/tests/test_rest_api_3D.py index a67a79109f33..19f543f12dc2 100644 --- a/cvat/apps/engine/tests/test_rest_api_3D.py +++ b/cvat/apps/engine/tests/test_rest_api_3D.py @@ -20,9 +20,9 @@ from django.contrib.auth.models import Group, User from rest_framework import status +from cvat.apps.dataset_manager.tests.utils import TestDir from cvat.apps.engine.media_extractors import ValidateDimension from cvat.apps.dataset_manager.task import TaskAnnotation -from datumaro.util.test_utils import TestDir from cvat.apps.engine.tests.utils import get_paginated_collection, ApiTestBase, ForceLogin