Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: LEAP-24: /api/tasks performance improvement #4738

Merged
merged 7 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions label_studio/data_manager/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
"""
import logging

from asgiref.sync import async_to_sync, sync_to_async
from core.feature_flags import flag_set
from core.permissions import ViewClassPermission, all_permissions
from core.utils.common import int_from_request, load_func
from core.utils.params import bool_from_request
Expand Down Expand Up @@ -135,11 +137,26 @@ class TaskPagination(PageNumberPagination):
total_predictions = 0
max_page_size = settings.TASK_API_PAGE_SIZE_MAX

def paginate_queryset(self, queryset, request, view=None):
@async_to_sync
async def async_paginate_queryset(self, queryset, request, view=None):
predictions_count_qs = Prediction.objects.filter(task_id__in=queryset)
self.total_predictions = await sync_to_async(predictions_count_qs.count, thread_sensitive=True)()

annotations_count_qs = Annotation.objects.filter(task_id__in=queryset, was_cancelled=False)
self.total_annotations = await sync_to_async(annotations_count_qs.count, thread_sensitive=True)()
return await sync_to_async(super().paginate_queryset, thread_sensitive=True)(queryset, request, view)

def sync_paginate_queryset(self, queryset, request, view=None):
self.total_predictions = Prediction.objects.filter(task_id__in=queryset).count()
self.total_annotations = Annotation.objects.filter(task_id__in=queryset, was_cancelled=False).count()
return super().paginate_queryset(queryset, request, view)

def paginate_queryset(self, queryset, request, view=None):
if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
return self.async_paginate_queryset(queryset, request, view)
else:
return self.sync_paginate_queryset(queryset, request, view)

def get_paginated_response(self, data):
return Response(
{
Expand Down Expand Up @@ -244,7 +261,15 @@ def get(self, request):
evaluate_predictions(tasks_for_predictions)
[tasks_by_ids[_id].refresh_from_db() for _id in ids]

serializer = self.task_serializer_class(page, many=True, context=context)
if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
serializer = self.task_serializer_class(
page,
many=True,
context=context,
include=get_fields_for_evaluation(prepare_params, request.user, skip_regular=False),
)
else:
serializer = self.task_serializer_class(page, many=True, context=context)
return self.get_paginated_response(serializer.data)
# all tasks
if project.evaluate_predictions_automatically:
Expand Down
11 changes: 6 additions & 5 deletions label_studio/data_manager/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_fields_for_filter_ordering(prepare_params):
return result


def get_fields_for_evaluation(prepare_params, user):
def get_fields_for_evaluation(prepare_params, user, skip_regular=True):
"""Collecting field names to annotate them

:param prepare_params: structure with filters and ordering
Expand Down Expand Up @@ -110,10 +110,11 @@ def get_fields_for_evaluation(prepare_params, user):
result = set(result)

# we don't need to annotate regular model fields, so we skip them
skipped_fields = [field.attname for field in Task._meta.fields]
skipped_fields.append('id')
result = [f for f in result if f not in skipped_fields]
result = [f for f in result if not f.startswith('data.')]
if skip_regular:
skipped_fields = [field.attname for field in Task._meta.fields]
skipped_fields.append('id')
result = [f for f in result if f not in skipped_fields]
result = [f for f in result if not f.startswith('data.')]

return result

Expand Down