Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix file matching in annotation import for multiple dots in filenames #6350

Merged
merged 12 commits into from
Jun 29, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- \[API\] Invalid schema for the owner field in several endpoints (<https://github.com/opencv/cvat/pull/6343>)
- \[SDK\] Loading tasks that have been cached with the PyTorch adapter
(<https://github.com/opencv/cvat/issues/6047>)
- The problem with importing annotations if dataset has extra dots in filenames (<https://github.com/opencv/cvat/pull/6350>)

### Security
- TDB
Expand Down
60 changes: 43 additions & 17 deletions cvat/apps/dataset_manager/bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(self,
self._create_callback = create_callback
self._MAX_ANNO_SIZE = 30000
self._frame_info = {}
self._frame_mapping = {}
self._frame_mapping: Dict[str, int] = {}
self._frame_step = db_task.data.get_frame_step()
self._db_data = db_task.data
self._use_server_track_ids = use_server_track_ids
Expand Down Expand Up @@ -613,28 +613,37 @@ def __len__(self):
raise NotImplementedError()

@staticmethod
def _get_filename(path):
def _get_filename(path: str) -> str:
return osp.splitext(path)[0]

def match_frame(self, path, root_hint=None, path_has_ext=True):
def match_frame(self,
path: str, root_hint: Optional[str] = None, *, path_has_ext: bool = True
) -> Optional[int]:
if path_has_ext:
path = self._get_filename(path)

match = self._frame_mapping.get(path)

if not match and root_hint and not path.startswith(root_hint):
path = osp.join(root_hint, path)
match = self._frame_mapping.get(path)

return match

def match_frame_fuzzy(self, path):
def match_frame_fuzzy(self, path: str, *, path_has_ext: bool = True) -> Optional[int]:
# Preconditions:
# - The input dataset is full, i.e. all items present. Partial dataset
# matching can't be correct for all input cases.
# - path is the longest path of input dataset in terms of path parts

path = Path(self._get_filename(path)).parts
if path_has_ext:
path = self._get_filename(path)

path = Path(path).parts
for p, v in self._frame_mapping.items():
if Path(p).parts[-len(path):] == path: # endswith() for paths
return v

return None

class JobData(CommonData):
Expand Down Expand Up @@ -1254,20 +1263,30 @@ def task_data(self):
def _get_filename(path):
return osp.splitext(path)[0]

def match_frame(self, path: str, subset: str=dm.DEFAULT_SUBSET_NAME, root_hint: str=None, path_has_ext: bool=True):
def match_frame(self,
path: str, subset: str = dm.DEFAULT_SUBSET_NAME,
root_hint: str = None, path_has_ext: bool = True
) -> Optional[int]:
if path_has_ext:
path = self._get_filename(path)

match_task, match_frame = self._frame_mapping.get((subset, path), (None, None))

if not match_frame and root_hint and not path.startswith(root_hint):
path = osp.join(root_hint, path)
match_task, match_frame = self._frame_mapping.get((subset, path), (None, None))

return match_task, match_frame

def match_frame_fuzzy(self, path):
path = Path(self._get_filename(path)).parts
def match_frame_fuzzy(self, path: str, *, path_has_ext: bool = True) -> Optional[int]:
if path_has_ext:
path = self._get_filename(path)

path = Path(path).parts
for (_subset, _path), (_tid, frame_number) in self._frame_mapping.items():
if Path(_path).parts[-len(path):] == path :
return frame_number

return None

def split_dataset(self, dataset: dm.Dataset):
Expand Down Expand Up @@ -1814,7 +1833,11 @@ def convert_cvat_anno_to_dm(
return converter.convert()


def match_dm_item(item, instance_data, root_hint=None):
def match_dm_item(
item: dm.DatasetItem,
instance_data: Union[ProjectData, CommonData],
root_hint: Optional[str] = None
) -> int:
is_video = instance_data.meta[instance_data.META_FIELD]['mode'] == 'interpolation'

frame_number = None
Expand All @@ -1832,20 +1855,23 @@ def match_dm_item(item, instance_data, root_hint=None):
"'%s' with any task frame" % item.id)
return frame_number

def find_dataset_root(dm_dataset, instance_data: Union[ProjectData, CommonData]):
longest_path = max(dm_dataset, key=lambda x: len(Path(x.id).parts),
default=None)
if longest_path is None:
def find_dataset_root(
dm_dataset: dm.IDataset, instance_data: Union[ProjectData, CommonData]
) -> Optional[str]:
longest_path_item = max(dm_dataset, key=lambda item: len(Path(item.id).parts), default=None)
if longest_path_item is None:
return None
longest_path = longest_path.id
longest_path = longest_path_item.id

longest_match = instance_data.match_frame_fuzzy(longest_path)
if longest_match is None:
matched_frame_number = instance_data.match_frame_fuzzy(longest_path, path_has_ext=False)
if matched_frame_number is None:
return None
longest_match = osp.dirname(instance_data.frame_info[longest_match]['path'])

longest_match = osp.dirname(instance_data.frame_info[matched_frame_number]['path'])
prefix = longest_match[:-len(osp.dirname(longest_path)) or None]
if prefix.endswith('/'):
prefix = prefix[:-1]

return prefix

def import_dm_annotations(dm_dataset: dm.Dataset, instance_data: Union[ProjectData, CommonData]):
Expand Down
11 changes: 5 additions & 6 deletions cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2917,12 +2917,11 @@ def _import_annotations(request, rq_id_template, rq_func, db_obj, format_name,
elif rq_job.is_failed or \
rq_job.is_deferred and rq_job.dependency and rq_job.dependency.is_failed:
exc_info = process_failed_job(rq_job)
# RQ adds a prefix with exception class name
import_error_prefix = '{}.{}'.format(
CvatImportError.__module__, CvatImportError.__name__)
if import_error_prefix in exc_info:
return Response(data="The annotations that were uploaded are not correct",
status=status.HTTP_400_BAD_REQUEST)

import_error_prefix = f'{CvatImportError.__module__}.{CvatImportError.__name__}:'
if exc_info.startswith("Traceback") and import_error_prefix in exc_info:
exc_message = exc_info.split(import_error_prefix)[-1].strip()
return Response(data=exc_message, status=status.HTTP_400_BAD_REQUEST)
else:
return Response(data=exc_info,
status=status.HTTP_500_INTERNAL_SERVER_ERROR)
Expand Down
151 changes: 151 additions & 0 deletions tests/python/rest_api/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,3 +2172,154 @@ def test_check_import_cache_after_previous_interrupted_upload(self, tasks_with_s
if not number_of_files:
break
assert not number_of_files


class TestImportWithComplexFilenames:
@staticmethod
def _make_client() -> Client:
return Client(BASE_URL, config=Config(status_check_period=0.01))

@pytest.fixture(
autouse=True,
scope="class",
# classmethod way may not work in some versions
# https://github.com/opencv/cvat/actions/runs/5336023573/jobs/9670573955?pr=6350
name="TestImportWithComplexFilenames.setup_class",
)
@classmethod
def setup_class(
cls, restore_db_per_class, tmp_path_factory: pytest.TempPathFactory, admin_user: str
):
cls.tmp_dir = tmp_path_factory.mktemp(cls.__class__.__name__)
cls.client = cls._make_client()
cls.user = admin_user
cls.format_name = "PASCAL VOC 1.1"

with cls.client:
cls.client.login((cls.user, USER_PASS))

cls._init_tasks()

@classmethod
def _create_task_with_annotations(cls, filenames: List[str]):
images = generate_image_files(len(filenames), filenames=filenames)

source_archive_path = cls.tmp_dir / "source_data.zip"
with zipfile.ZipFile(source_archive_path, "w") as zip_file:
for image in images:
zip_file.writestr(image.name, image.getvalue())

task = cls.client.tasks.create_from_data(
{
"name": "test_images_with_dots",
"labels": [{"name": "cat"}, {"name": "dog"}],
},
resources=[source_archive_path],
)

labels = task.get_labels()
task.set_annotations(
models.LabeledDataRequest(
shapes=[
models.LabeledShapeRequest(
frame=frame_id,
label_id=labels[0].id,
type="rectangle",
points=[1, 1, 2, 2],
)
for frame_id in range(len(filenames))
],
)
)

return task

@classmethod
def _init_tasks(cls):
cls.flat_filenames = [
"filename0.jpg",
"file.name1.jpg",
"fi.le.na.me.2.jpg",
".filename3.jpg",
"..filename..4.jpg",
"..filename..5.png..jpg",
]

cls.nested_filenames = [
f"{prefix}/{fn}"
for prefix, fn in zip(
[
"ab/cd",
"ab/cd",
"ab",
"ab",
"cd/ef",
"cd/ef",
"cd",
"",
],
cls.flat_filenames,
)
]

cls.data = {}
for (kind, filenames), prefix in product(
[("flat", cls.flat_filenames), ("nested", cls.nested_filenames)], ["", "pre/fix"]
):
key = kind
if prefix:
key += "_prefixed"

task = cls._create_task_with_annotations(
[f"{prefix}/{fn}" if prefix else fn for fn in filenames]
)

dataset_file = cls.tmp_dir / f"{key}_dataset.zip"
task.export_dataset(cls.format_name, dataset_file, include_images=False)

cls.data[key] = (task, dataset_file)

@pytest.mark.parametrize(
"task_kind, annotation_kind, expect_success",
[
("flat", "flat", True),
("flat", "flat_prefixed", False),
("flat", "nested", False),
("flat", "nested_prefixed", False),
("flat_prefixed", "flat", True), # allow this for better UX
("flat_prefixed", "flat_prefixed", True),
("flat_prefixed", "nested", False),
("flat_prefixed", "nested_prefixed", False),
("nested", "flat", False),
("nested", "flat_prefixed", False),
("nested", "nested", True),
("nested", "nested_prefixed", False),
("nested_prefixed", "flat", False),
("nested_prefixed", "flat_prefixed", False),
("nested_prefixed", "nested", True), # allow this for better UX
("nested_prefixed", "nested_prefixed", True),
],
)
def test_import_annotations(self, task_kind, annotation_kind, expect_success):
# Tests for regressions about https://github.com/opencv/cvat/issues/6319
#
# X annotations must be importable to X prefixed cases
# with and without dots in filenames.
#
# Nested structures can potentially be matched to flat ones and vise-versa,
# but it's not supported now, as it may lead to some errors in matching.

task: Task = self.data[task_kind][0]
dataset_file = self.data[annotation_kind][1]

if expect_success:
task.import_annotations(self.format_name, dataset_file)

assert set(s.frame for s in task.get_annotations().shapes) == set(
range(len(self.flat_filenames))
)
else:
with pytest.raises(exceptions.ApiException) as capture:
task.import_annotations(self.format_name, dataset_file)

assert b"Could not match item id" in capture.value.body
13 changes: 10 additions & 3 deletions tests/python/shared/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import subprocess
from io import BytesIO
from typing import List
from typing import List, Optional

from PIL import Image

Expand All @@ -21,11 +21,18 @@ def generate_image_file(filename="image.png", size=(50, 50), color=(0, 0, 0)):
return f


def generate_image_files(count, prefixes=None) -> List[BytesIO]:
def generate_image_files(
count, prefixes=None, *, filenames: Optional[List[str]] = None
) -> List[BytesIO]:
assert not (prefixes and filenames), "prefixes cannot be used together with filenames"
assert not prefixes or len(prefixes) == count
assert not filenames or len(filenames) == count

images = []
for i in range(count):
prefix = prefixes[i] if prefixes else ""
image = generate_image_file(f"{prefix}{i}.jpeg", color=(i, i, i))
filename = f"{prefix}{i}.jpeg" if not filenames else filenames[i]
image = generate_image_file(filename, color=(i, i, i))
images.append(image)

return images
Expand Down