Skip to content

Commit

Permalink
Fix task creation with gt_pool validation and cloud storage data (#8539)
Browse files Browse the repository at this point in the history
  • Loading branch information
Marishka17 authored Oct 16, 2024
1 parent c557f70 commit 49ec1d1
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Task creation with cloud storage data and GT_POOL validation mode
(<https://github.com/cvat-ai/cvat/pull/8539>)
28 changes: 3 additions & 25 deletions cvat/apps/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -1157,19 +1157,6 @@ def _update_status(msg: str) -> None:
assert job_file_mapping[-1] == validation_params['frames']
job_file_mapping.pop(-1)

# Update manifest
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.link(
sources=[extractor.get_path(image.frame) for image in images],
meta={
k: {'related_images': related_images[k] }
for k in related_images
},
data_dir=upload_dir,
DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
)
manifest.create()

db_data.update_validation_layout(models.ValidationLayout(
mode=models.ValidationMode.GT_POOL,
frames=list(frame_idx_map.values()),
Expand Down Expand Up @@ -1324,24 +1311,15 @@ def _update_status(msg: str) -> None:
assert image.is_placeholder
image.real_frame = frame_id_map[image.real_frame]

# Update manifest
manifest.reorder([images[frame_idx_map[image.frame]].path for image in new_db_images])

images = new_db_images
db_data.size = len(images)
db_data.start_frame = 0
db_data.stop_frame = 0
db_data.frame_filter = ''

# Update manifest
manifest = ImageManifestManager(db_data.get_manifest_path())
manifest.link(
sources=[extractor.get_path(frame_idx_map[image.frame]) for image in images],
meta={
k: {'related_images': related_images[k] }
for k in related_images
},
data_dir=upload_dir,
DIM_3D=(db_task.dimension == models.DimensionType.DIM_3D),
)
manifest.create()

db_data.update_validation_layout(models.ValidationLayout(
mode=models.ValidationMode.GT_POOL,
Expand Down
95 changes: 89 additions & 6 deletions tests/python/rest_api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ClassVar,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Expand Down Expand Up @@ -1529,12 +1530,13 @@ def _create_task_with_cloud_data(
server_files: List[str],
use_cache: bool = True,
sorting_method: str = "lexicographical",
spec: Optional[Dict[str, Any]] = None,
data_type: str = "image",
video_frame_count: int = 10,
server_files_exclude: Optional[List[str]] = None,
org: Optional[str] = None,
org: str = "",
filenames: Optional[List[str]] = None,
task_spec_kwargs: Optional[Dict[str, Any]] = None,
data_spec_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[int, Any]:
s3_client = s3.make_client(bucket=cloud_storage["resource"])
if data_type == "video":
Expand All @@ -1551,7 +1553,9 @@ def _create_task_with_cloud_data(
)
else:
images = generate_image_files(
3, **({"prefixes": ["img_"] * 3} if not filenames else {"filenames": filenames})
3,
sizes=[(100, 50) if i % 2 else (50, 100) for i in range(3)],
**({"prefixes": ["img_"] * 3} if not filenames else {"filenames": filenames}),
)

for image in images:
Expand Down Expand Up @@ -1598,6 +1602,7 @@ def _create_task_with_cloud_data(
"name": "car",
}
],
**(task_spec_kwargs or {}),
}

data_spec = {
Expand All @@ -1608,9 +1613,8 @@ def _create_task_with_cloud_data(
server_files if not use_manifest else server_files + ["test/manifest.jsonl"]
),
"sorting_method": sorting_method,
**(data_spec_kwargs or {}),
}
if spec is not None:
data_spec.update(spec)

if server_files_exclude:
data_spec["server_files_exclude"] = server_files_exclude
Expand Down Expand Up @@ -1984,7 +1988,7 @@ def test_create_task_with_cloud_storage_and_check_retrieve_data_meta(
use_cache=False,
server_files=["test/video/video.avi"],
org=org,
spec=data_spec,
data_spec_kwargs=data_spec,
data_type="video",
)

Expand Down Expand Up @@ -2550,6 +2554,85 @@ def test_can_create_task_with_gt_job_from_video(
else:
assert len(validation_frames) == validation_frames_count

@pytest.mark.with_external_services
@pytest.mark.parametrize("cloud_storage_id", [2])
@pytest.mark.parametrize(
    "validation_mode",
    [
        models.ValidationMode("gt"),
        models.ValidationMode("gt_pool"),
    ],
)
def test_can_create_task_with_validation_and_cloud_data(
    self,
    cloud_storage_id: int,
    validation_mode: models.ValidationMode,
    request: pytest.FixtureRequest,
    admin_user: str,
    cloud_storages: Iterable,
):
    """
    Create a task from cloud-storage data with a validation mode ("gt" or
    "gt_pool") enabled, then check that the expected jobs were created and
    that each job's frame metadata (height/width) matches the images actually
    stored in the job's compressed chunk.
    """
    cloud_storage = cloud_storages[cloud_storage_id]
    # 3 images in the bucket; exactly one of them is used as a validation frame
    server_files = [f"test/sub_0/img_{i}.jpeg" for i in range(3)]
    validation_frames = ["test/sub_0/img_1.jpeg"]

    (task_id, _) = self._create_task_with_cloud_data(
        request,
        cloud_storage,
        use_manifest=False,
        server_files=server_files,
        sorting_method=models.SortingMethod(
            "random"
        ),  # only random sorting can be used with gt_pool
        data_spec_kwargs={
            "validation_params": models.DataRequestValidationParams._from_openapi_data(
                mode=validation_mode,
                frames=validation_frames,
                frame_selection_method=models.FrameSelectionMethod("manual"),
                frames_per_job_count=1,
            )
        },
        task_spec_kwargs={
            # in case of gt_pool: each regular job will contain 1 regular and 1 validation frames,
            # (number of validation frames is not included into segment_size)
            "segment_size": 1,
        },
    )

    with make_api_client(admin_user) as api_client:
        # check that GT job was created
        (paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="ground_truth")
        assert 1 == len(paginated_jobs["results"])

        # in gt_pool mode the validation frames are excluded from the regular
        # jobs, so with segment_size=1 there is one annotation job per
        # non-validation frame; in gt mode every frame gets a regular job
        (paginated_jobs, _) = api_client.jobs_api.list(task_id=task_id, type="annotation")
        jobs_count = (
            len(server_files) - len(validation_frames)
            if validation_mode == models.ValidationMode("gt_pool")
            else len(server_files)
        )
        assert jobs_count == len(paginated_jobs["results"])
        # check that the returned meta of images corresponds to the chunk data
        # Note: meta is based on the order of images from database
        # while chunk with CS data is based on the order of images in a manifest
        for job in paginated_jobs["results"]:
            (job_meta, _) = api_client.jobs_api.retrieve_data_meta(job["id"])
            (_, response) = api_client.jobs_api.retrieve_data(
                job["id"], type="chunk", quality="compressed", index=0
            )
            chunk_file = io.BytesIO(response.data)
            assert zipfile.is_zipfile(chunk_file)

            with zipfile.ZipFile(chunk_file, "r") as chunk_archive:
                # archive member names are frame numbers; decode each image
                # and order the mapping by frame number for the comparison below
                chunk_images = {
                    int(os.path.splitext(name)[0]): np.array(
                        Image.open(io.BytesIO(chunk_archive.read(name)))
                    )
                    for name in chunk_archive.namelist()
                }
                chunk_images = dict(sorted(chunk_images.items(), key=lambda e: e[0]))

            # each chunk image must have the dimensions reported by the job meta
            for img, img_meta in zip(chunk_images.values(), job_meta.frames):
                assert (img.shape[0], img.shape[1]) == (img_meta.height, img_meta.width)


class _SourceDataType(str, Enum):
images = "images"
Expand Down
12 changes: 9 additions & 3 deletions tests/python/shared/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
from contextlib import closing
from io import BytesIO
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Tuple

import av
import av.video.reformatter
Expand All @@ -25,7 +25,11 @@ def generate_image_file(filename="image.png", size=(100, 50), color=(0, 0, 0)):


def generate_image_files(
count, prefixes=None, *, filenames: Optional[List[str]] = None
count: int,
*,
prefixes: Optional[List[str]] = None,
filenames: Optional[List[str]] = None,
sizes: Optional[List[Tuple[int, int]]] = None,
) -> List[BytesIO]:
assert not (prefixes and filenames), "prefixes cannot be used together with filenames"
assert not prefixes or len(prefixes) == count
Expand All @@ -35,7 +39,9 @@ def generate_image_files(
for i in range(count):
prefix = prefixes[i] if prefixes else ""
filename = f"{prefix}{i}.jpeg" if not filenames else filenames[i]
image = generate_image_file(filename, color=(i, i, i))
image = generate_image_file(
filename, color=(i, i, i), **({"size": sizes[i]}) if sizes else {}
)
images.append(image)

return images
Expand Down
15 changes: 15 additions & 0 deletions utils/dataset_manifest/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,21 @@ def emulate_hierarchical_structure(
'next': next_start_index,
}

def reorder(self, reordered_images: List[str]) -> None:
    """
    Recreate the manifest content following the order of *reordered_images*.

    The method takes a list of image names and reorders the manifest content
    based on this new list. Due to the implementation of Honeypots, the
    reordered list of image names may contain duplicates.

    Args:
        reordered_images: image names (full names, as stored in the manifest)
            in the desired order; names may repeat.

    Raises:
        InvalidManifestError: if *reordered_images* mentions an image that is
            absent from the current manifest.
    """
    # Index existing entries by full name, keeping the first occurrence of
    # each name so duplicates in the manifest do not overwrite earlier details.
    unique_images: Dict[str, Any] = {}
    for _, image_details in self:
        if image_details.full_name not in unique_images:
            unique_images[image_details.full_name] = image_details

    try:
        self.create(content=(unique_images[x] for x in reordered_images))
    except KeyError as ex:
        # Chain explicitly (PEP 3134) so the traceback shows the missing-name
        # KeyError as the direct cause rather than an incidental second error.
        raise InvalidManifestError(f"Previous manifest does not contain {ex} image") from ex

class _BaseManifestValidator(ABC):
def __init__(self, full_manifest_path):
self._manifest = _Manifest(full_manifest_path)
Expand Down

0 comments on commit 49ec1d1

Please sign in to comment.