Skip to content

Commit

Permalink
Update nested objects on parent labels update (#6958)
Browse files Browse the repository at this point in the history
Fixes #6871

Added batch update (note - no signals issued on this) for owning and nested objects (tasks, jobs) on parent (task, project) labels updates
  • Loading branch information
zhiltsov-max authored Oct 20, 2023
1 parent 9004b27 commit 99e4801
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 2 deletions.
4 changes: 4 additions & 0 deletions changelog.d/20231019_114611_mzhiltsov_fix_label_updates.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
### Fixed

- Label updates didn't change project/task/job update time
(<https://github.com/opencv/cvat/pull/6958>)
3 changes: 3 additions & 0 deletions cvat/apps/engine/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,9 @@ def update_or_create(self, *args, **kwargs: Any):
return super().update_or_create(*args, **kwargs)

def _validate_constraints(self, obj: Dict[str, Any]):
if 'type' not in obj:
return

# Constraints can't be set on the related model fields
# This method requires the save operation to be called as a transaction
if obj['type'] == JobType.GROUND_TRUTH and self.filter(
Expand Down
24 changes: 24 additions & 0 deletions cvat/apps/engine/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,8 +1204,17 @@ def update(self, instance, validated_data):
_update_related_storages(instance, validated_data)

instance.save()

if 'label_set' in validated_data and not instance.project_id:
self.update_child_objects_on_labels_update(instance)

return instance

def update_child_objects_on_labels_update(self, instance: models.Task):
models.Job.objects.filter(
updated_date__lt=instance.updated_date, segment__task=instance
).update(updated_date=instance.updated_date)

def validate(self, attrs):
# When moving task labels can be mapped to one, but when not names must be unique
if 'project_id' in attrs.keys() and self.instance is not None:
Expand Down Expand Up @@ -1344,8 +1353,23 @@ def update(self, instance, validated_data):
_update_related_storages(instance, validated_data)

instance.save()

if 'label_set' in validated_data:
self.update_child_objects_on_labels_update(instance)

return instance

@transaction.atomic
def update_child_objects_on_labels_update(self, instance: models.Project):
models.Task.objects.filter(
updated_date__lt=instance.updated_date, project=instance
).update(updated_date=instance.updated_date)

models.Job.objects.filter(
updated_date__lt=instance.updated_date, segment__task__project=instance
).update(updated_date=instance.updated_date)


class AboutSerializer(serializers.Serializer):
name = serializers.CharField(max_length=128)
description = serializers.CharField(max_length=2048)
Expand Down
9 changes: 8 additions & 1 deletion cvat/apps/engine/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,14 +2251,21 @@ def perform_update(self, serializer):

return super().perform_update(serializer)

def perform_destroy(self, instance):
def perform_destroy(self, instance: models.Label):
if instance.parent is not None:
# NOTE: this can be relaxed when skeleton updates are implemented properly
raise ValidationError(
"Sublabels cannot be deleted this way. "
"Please send a PATCH request with updated parent label data instead.",
code=status.HTTP_400_BAD_REQUEST)

if project := instance.project:
project.save(update_fields=['updated_date'])
ProjectWriteSerializer(project).update_child_objects_on_labels_update(project)
elif task := instance.task:
task.save(update_fields=['updated_date'])
TaskWriteSerializer(task).update_child_objects_on_labels_update(task)

return super().perform_destroy(instance)


Expand Down
1 change: 1 addition & 0 deletions tests/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ requests==2.31.0
deepdiff==5.6.0
boto3==1.17.61
Pillow==10.0.1
python-dateutil==2.8.2
pyyaml==6.0.0
numpy==1.22.0
99 changes: 98 additions & 1 deletion tests/python/rest_api/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
from cvat_sdk import exceptions, models
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.core.helpers import get_paginated_collection
from dateutil.parser import isoparse as parse_datetime
from deepdiff import DeepDiff
from pytest_cases import fixture, fixture_ref, parametrize

from shared.utils.config import make_api_client
from shared.utils.config import delete_method, get_method, make_api_client, patch_method

from .utils import CollectionSimpleFilterTestBase, build_exclude_paths_expr, get_attrs

Expand Down Expand Up @@ -823,3 +824,99 @@ def test_regular_user_delete_org_label(self, user_org_case):
self._test_delete_ok(user["username"], label["id"])
else:
self._test_delete_denied(user["username"], label["id"])


@pytest.mark.usefixtures("restore_db_per_function")
class TestLabelUpdates:
@pytest.mark.parametrize("update_kind", ["addition", "removal", "modification"])
def test_project_label_update_triggers_nested_task_and_job_update(
self, update_kind, admin_user, labels, projects_wlc, tasks, jobs
):
# Checks for regressions against the issue https://github.com/opencv/cvat/issues/6871

project = next(p for p in projects_wlc if p["tasks"]["count"] and p["labels"]["count"])
project_labels = [l for l in labels if l.get("project_id") == project["id"]]
nested_tasks = [t for t in tasks if t["project_id"] == project["id"]]
nested_task_ids = set(t["id"] for t in nested_tasks)
nested_jobs = [j for j in jobs if j["task_id"] in nested_task_ids]

if update_kind == "addition":
response = patch_method(
admin_user, f'projects/{project["id"]}', {"labels": [{"name": "dog2"}]}
)
updated_project = response.json()
elif update_kind == "modification":
label = project_labels[0]
patch_method(admin_user, f'labels/{label["id"]}', {"name": label["name"] + "-updated"})

response = get_method(admin_user, f'projects/{project["id"]}')
updated_project = response.json()
elif update_kind == "removal":
label = project_labels[0]
delete_method(admin_user, f'labels/{label["id"]}')

response = get_method(admin_user, f'projects/{project["id"]}')
updated_project = response.json()
else:
assert False

with make_api_client(admin_user) as api_client:
updated_tasks = get_paginated_collection(
api_client.tasks_api.list_endpoint, project_id=project["id"], return_json=True
)

updated_jobs = [
j
for j in get_paginated_collection(
api_client.jobs_api.list_endpoint, return_json=True
)
if j["task_id"] in nested_task_ids
]

assert parse_datetime(project["updated_date"]) < parse_datetime(
updated_project["updated_date"]
)
assert len(updated_tasks) == len(nested_tasks)
assert len(updated_jobs) == len(nested_jobs)
for entity in updated_tasks + updated_jobs:
assert updated_project["updated_date"] == entity["updated_date"]

@pytest.mark.parametrize("update_kind", ["addition", "removal", "modification"])
def test_task_label_update_triggers_nested_task_and_job_update(
self, update_kind, admin_user, labels, tasks_wlc, jobs
):
# Checks for regressions against the issue https://github.com/opencv/cvat/issues/6871

task = next(t for t in tasks_wlc if t["jobs"]["count"] and t["labels"]["count"])
task_labels = [l for l in labels if l.get("task_id") == task["id"]]
nested_jobs = [j for j in jobs if j["task_id"] == task["id"]]

if update_kind == "addition":
response = patch_method(
admin_user, f'tasks/{task["id"]}', {"labels": [{"name": "dog2"}]}
)
updated_task = response.json()
elif update_kind == "modification":
label = task_labels[0]
patch_method(admin_user, f'labels/{label["id"]}', {"name": label["name"] + "-updated"})

response = get_method(admin_user, f'tasks/{task["id"]}')
updated_task = response.json()
elif update_kind == "removal":
label = task_labels[0]
delete_method(admin_user, f'labels/{label["id"]}')

response = get_method(admin_user, f'tasks/{task["id"]}')
updated_task = response.json()
else:
assert False

with make_api_client(admin_user) as api_client:
updated_jobs = get_paginated_collection(
api_client.jobs_api.list_endpoint, task_id=task["id"], return_json=True
)

assert parse_datetime(task["updated_date"]) < parse_datetime(updated_task["updated_date"])
assert len(updated_jobs) == len(nested_jobs)
for job in updated_jobs:
assert updated_task["updated_date"] == job["updated_date"]

0 comments on commit 99e4801

Please sign in to comment.