diff --git a/cvat/apps/engine/tests/test_rest_api.py b/cvat/apps/engine/tests/test_rest_api.py index a5c279609c05..f07df7066f07 100644 --- a/cvat/apps/engine/tests/test_rest_api.py +++ b/cvat/apps/engine/tests/test_rest_api.py @@ -28,6 +28,7 @@ from rest_framework import status from rest_framework.test import APIClient, APITestCase +from datumaro.util.test_utils import TestDir from cvat.apps.engine.models import (AttributeSpec, AttributeType, Data, Job, Project, Segment, StatusChoice, Task, Label, StorageMethodChoice, StorageChoice) from cvat.apps.engine.media_extractors import ValidateDimension @@ -5283,3 +5284,98 @@ def test_api_v1_server_share_observer(self): def test_api_v1_server_share_no_auth(self): response = self._run_api_v1_server_share(None, "/") self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + +class TaskAnnotation2DContext(APITestCase): + def setUp(self): + self.client = APIClient() + self.task = { + "name": "my archive task without copying #11", + "overlap": 0, + "segment_size": 0, + "labels": [ + {"name": "car"}, + {"name": "person"}, + ] + } + + @classmethod + def setUpTestData(cls): + create_db_users(cls) + + def _get_request_with_data(self, path, data, user): + with ForceLogin(user, self.client): + response = self.client.get(path, data) + return response + + def _get_request(self, path, user): + with ForceLogin(user, self.client): + response = self.client.get(path) + return response + + def _create_task(self, data, image_data): + with ForceLogin(self.user, self.client): + response = self.client.post('/api/v1/tasks', data=data, format="json") + assert response.status_code == status.HTTP_201_CREATED, response.status_code + tid = response.data["id"] + + response = self.client.post("/api/v1/tasks/%s/data" % tid, + data=image_data) + assert response.status_code == status.HTTP_202_ACCEPTED, response.status_code + + response = self.client.get("/api/v1/tasks/%s" % tid) + task = response.data + + return task + + def create_zip_archive_with_related_images(self, file_name, test_dir, context_images_info): + with tempfile.TemporaryDirectory() as tmp_dir: + for img in context_images_info: + image = Image.new('RGB', size=(100, 50)) + image.save(osp.join(tmp_dir, img), 'png') + if context_images_info[img]: + related_path = osp.join(tmp_dir, "related_images", img.replace(".", "_")) + os.makedirs(related_path) + image.save(osp.join(related_path, f"related_{img}"), 'png') + + zip_file_path = osp.join(test_dir, file_name) + shutil.make_archive(zip_file_path, 'zip', tmp_dir) + return f"{zip_file_path}.zip" + + def test_check_flag_has_related_context(self): + with TestDir() as test_dir: + test_cases = { + "All images with context": {"image_1.png": True, "image_2.png": True}, + "One image with context": {"image_1.png": True, "image_2.png": False} + } + for test_case, context_img_data in test_cases.items(): + filename = self.create_zip_archive_with_related_images(test_case, test_dir, context_img_data) + img_data = { + "client_files[0]": open(filename, 'rb'), + "image_quality": 75, + } + task = self._create_task(self.task, img_data) + task_id = task["id"] + + response = self._get_request("/api/v1/tasks/%s/data/meta" % task_id, self.admin) + for frame in response.data["frames"]: + self.assertEqual(context_img_data[frame["name"]], frame["has_related_context"]) + + def test_fetch_related_image_from_server(self): + test_name = self._testMethodName + context_img_data ={"image_1.png": True} + with TestDir() as test_dir: + filename = self.create_zip_archive_with_related_images(test_name, test_dir, context_img_data) + img_data = { + "client_files[0]": open(filename, 'rb'), + "image_quality": 75, + } + task = self._create_task(self.task , img_data) + task_id = task["id"] + data = { + "quality": "original", + "type": "context_image", + "number": 0 + } + response = self._get_request_with_data("/api/v1/tasks/%s/data" % task_id, data, self.admin) + self.assertEqual(response.status_code, status.HTTP_200_OK)