Skip to content

Commit

Permalink
Merge pull request #4268 from ozer550/fix_celery_race_condition
Browse files Browse the repository at this point in the history
Resolve Celery TaskObject Race Condition
  • Loading branch information
bjester authored Oct 30, 2023
2 parents 92454c4 + 3f9a853 commit aef05d0
Show file tree
Hide file tree
Showing 14 changed files with 272 additions and 158 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Generated by Django 3.2.19 on 2023-09-14 05:08
import django.core.validators
import django.db.models.deletion
from celery import states
from django.conf import settings
from django.db import migrations
from django.db import models

def transfer_data(apps, schema_editor):
CustomTaskMetadata = apps.get_model('contentcuration', 'CustomTaskMetadata')
TaskResult = apps.get_model('django_celery_results', 'taskresult')

old_task_results = TaskResult.objects.filter(status__in=states.UNREADY_STATES)

for old_task_result in old_task_results:
CustomTaskMetadata.objects.create(
task_id=old_task_result.task_id,
user=old_task_result.user,
channel_id=old_task_result.channel_id,
progress=old_task_result.progress,
signature=old_task_result.signature,
)

class Migration(migrations.Migration):

dependencies = [
('contentcuration', '0144_soft_delete_user'),
]

operations = [
migrations.CreateModel(
name='CustomTaskMetadata',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('task_id', models.CharField(max_length=255, unique=True)),
('channel_id', models.UUIDField(blank=True, db_index=True, null=True)),
('progress', models.IntegerField(blank=True, null=True, validators=[django.core.validators.MinValueValidator(0), django.core.validators.MaxValueValidator(100)])),
('signature', models.CharField(max_length=32, null=True)),
('date_created', models.DateTimeField(auto_now_add=True, help_text='Datetime field when the custom_metadata for task was created in UTC', verbose_name='Created DateTime')),
('user', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tasks', to=settings.AUTH_USER_MODEL)),
],
),
migrations.AddIndex(
model_name='customtaskmetadata',
index=models.Index(fields=['signature'], name='task_result_signature'),
),
migrations.RunPython(transfer_data),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Generated by Django 3.2.19 on 2023-09-14 10:42
from django.db import migrations

class Migration(migrations.Migration):

replaces = [('django_celery_results', '0145_custom_task_metadata'),]

def __init__(self, name, app_label):
super(Migration, self).__init__(name, 'django_celery_results')

dependencies = [
('contentcuration', '0145_custom_task_metadata'),
('contentcuration', '0141_add_task_signature'),
]

operations = [
migrations.RemoveField(
model_name='taskresult',
name='channel_id',
),
migrations.RemoveField(
model_name='taskresult',
name='progress',
),
migrations.RemoveField(
model_name='taskresult',
name='user',
),
migrations.RemoveField(
model_name='taskresult',
name='signature',
),
migrations.RemoveIndex(
model_name='taskresult',
name='task_result_signature_idx',
),
]
55 changes: 12 additions & 43 deletions contentcuration/contentcuration/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from datetime import datetime

import pytz
from celery import states as celery_states
from django.conf import settings
from django.contrib.auth.base_user import AbstractBaseUser
from django.contrib.auth.base_user import BaseUserManager
Expand Down Expand Up @@ -46,7 +45,6 @@
from django.dispatch import receiver
from django.utils import timezone
from django.utils.translation import gettext as _
from django_celery_results.models import TaskResult
from django_cte import With
from le_utils import proquint
from le_utils.constants import content_kinds
Expand Down Expand Up @@ -2566,58 +2564,29 @@ def serialize_to_change_dict(self):
return self.serialize(self)


class TaskResultCustom(object):
"""
Custom fields to add to django_celery_results's TaskResult model
If adding fields to this class, run `makemigrations` then move the generated migration from the
`django_celery_results` app to the `contentcuration` app and override the constructor to change
the app_label. See `0141_add_task_signature` for an example
"""
class CustomTaskMetadata(models.Model):
# Task_id for reference
task_id = models.CharField(
max_length=255, # Adjust the max_length as needed
unique=True,
)
# user shouldn't be null, but in order to append the field, this needs to be allowed
user = models.ForeignKey(settings.AUTH_USER_MODEL, related_name="tasks", on_delete=models.CASCADE, null=True)
channel_id = DjangoUUIDField(db_index=True, null=True, blank=True)
progress = models.IntegerField(null=True, blank=True, validators=[MinValueValidator(0), MaxValueValidator(100)])
# a hash of the task name and kwargs for identifying repeat tasks
signature = models.CharField(null=True, blank=False, max_length=32)

super_as_dict = TaskResult.as_dict

def as_dict(self):
"""
:return: A dictionary representation
"""
super_dict = self.super_as_dict()
super_dict.update(
user_id=self.user_id,
channel_id=self.channel_id,
progress=self.progress,
)
return super_dict

@classmethod
def contribute_to_class(cls, model_class=TaskResult):
"""
Adds fields to model, by default TaskResult
:param model_class: TaskResult model
"""
for field in dir(cls):
if not field.startswith("_") and field not in ('contribute_to_class', 'Meta'):
model_class.add_to_class(field, getattr(cls, field))

# manually add Meta afterwards
setattr(model_class._meta, 'indexes', getattr(model_class._meta, 'indexes', []) + cls.Meta.indexes)
date_created = models.DateTimeField(
auto_now_add=True,
verbose_name=_('Created DateTime'),
help_text=_('Datetime field when the custom_metadata for task was created in UTC')
)

class Meta:
indexes = [
# add index that matches query usage for signature
models.Index(
fields=['signature'],
name='task_result_signature_idx',
condition=Q(status__in=celery_states.UNREADY_STATES),
name='task_result_signature',
),
]


# trigger class contributions immediately
TaskResultCustom.contribute_to_class()
2 changes: 1 addition & 1 deletion contentcuration/contentcuration/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import mock
from celery import states
from django_celery_results.models import TaskResult
from search.models import ContentNodeFullTextSearch

from contentcuration.models import ContentNode
from contentcuration.models import TaskResult


def clear_tasks(except_task_id=None):
Expand Down
64 changes: 31 additions & 33 deletions contentcuration/contentcuration/tests/test_asynctask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import

import threading
import time
import uuid

import pytest
Expand All @@ -9,12 +10,11 @@
from celery.utils.log import get_task_logger
from django.core.management import call_command
from django.test import TransactionTestCase
from django_celery_results.models import TaskResult

from . import testdata
from .helpers import clear_tasks
from contentcuration.celery import app
from contentcuration.models import TaskResult


logger = get_task_logger(__name__)

Expand All @@ -36,6 +36,22 @@ def test_task(self, **kwargs):
return 42


# Create a Task that takes a bit longer to be executed
@app.task(bind=True, name="test_task_delayed")
def test_task_delayed(self, delay_seconds=100, **kwargs):
"""
This is a mock task that takes a bit longer to execute
so that revoke function can be called succesfully without test being SUCCESS before hand,
to be used ONLY for unit-testing various pieces of the
async task API.
:return: The number 42
"""
time.sleep(delay_seconds)
logger.info("Request ID = {}".format(self.request.id))
assert TaskResult.objects.filter(task_id=self.request.id).count() == 1
return 42


@app.task(name="error_test_task")
def error_test_task(**kwargs):
"""
Expand Down Expand Up @@ -88,11 +104,8 @@ def _celery_task_worker():
])


def _celery_progress_monitor(thread_event):
def _on_iteration(receiver):
if thread_event.is_set():
receiver.should_stop = True
app.events.monitor_progress(on_iteration=_on_iteration)
def _return_celery_task_object(task_id):
return TaskResult.objects.get(task_id=task_id)


class AsyncTaskTestCase(TransactionTestCase):
Expand All @@ -108,11 +121,6 @@ class AsyncTaskTestCase(TransactionTestCase):
@classmethod
def setUpClass(cls):
super(AsyncTaskTestCase, cls).setUpClass()
# start progress monitor in separate thread
cls.monitor_thread_event = threading.Event()
cls.monitor_thread = threading.Thread(target=_celery_progress_monitor, args=(cls.monitor_thread_event,))
cls.monitor_thread.start()

# start celery worker in separate thread
cls.worker_thread = threading.Thread(target=_celery_task_worker)
cls.worker_thread.start()
Expand All @@ -122,8 +130,6 @@ def tearDownClass(cls):
super(AsyncTaskTestCase, cls).tearDownClass()
# tell the work thread to stop through the celery control API
if cls.worker_thread:
cls.monitor_thread_event.set()
cls.monitor_thread.join(5)
app.control.shutdown()
cls.worker_thread.join(5)

Expand Down Expand Up @@ -152,15 +158,14 @@ def test_asynctask_reports_success(self):
contains the return value of the task.
"""
async_result = test_task.enqueue(self.user)

result = self._wait_for(async_result)
task_result = async_result.get_model()
celery_task_result = TaskResult.objects.get(task_id=task_result.task_id)
self.assertEqual(task_result.user, self.user)

self.assertEqual(result, 42)
task_result.refresh_from_db()
self.assertEqual(async_result.result, 42)
self.assertEqual(task_result.task_name, "test_task")
self.assertEqual(celery_task_result.task_name, "test_task")
self.assertEqual(async_result.status, states.SUCCESS)
self.assertEqual(TaskResult.objects.get(task_id=async_result.id).result, "42")
self.assertEqual(TaskResult.objects.get(task_id=async_result.id).status, states.SUCCESS)
Expand All @@ -177,11 +182,12 @@ def test_asynctask_reports_error(self):
self._wait_for(async_result)

task_result = async_result.get_model()
self.assertEqual(task_result.status, states.FAILURE)
self.assertIsNotNone(task_result.traceback)
celery_task_result = _return_celery_task_object(task_result.task_id)
self.assertEqual(celery_task_result.status, states.FAILURE)
self.assertIsNotNone(celery_task_result.traceback)

self.assertIn(
"I'm sorry Dave, I'm afraid I can't do that.", task_result.result
"I'm sorry Dave, I'm afraid I can't do that.", celery_task_result.result
)

def test_only_create_async_task_creates_task_entry(self):
Expand All @@ -194,17 +200,6 @@ def test_only_create_async_task_creates_task_entry(self):
self.assertEquals(result, 42)
self.assertEquals(TaskResult.objects.filter(task_id=async_result.task_id).count(), 0)

def test_enqueue_task_adds_result_with_necessary_info(self):
async_result = test_task.enqueue(self.user, is_test=True)
try:
task_result = TaskResult.objects.get(task_id=async_result.task_id)
except TaskResult.DoesNotExist:
self.fail('Missing task result')

self.assertEqual(task_result.task_name, test_task.name)
_, _, encoded_kwargs = test_task.backend.encode_content(dict(is_test=True))
self.assertEqual(task_result.task_kwargs, encoded_kwargs)

@pytest.mark.skip(reason="This test is flaky on Github Actions")
def test_fetch_or_enqueue_task(self):
expected_task = test_task.enqueue(self.user, is_test=True)
Expand Down Expand Up @@ -258,8 +253,11 @@ def test_requeue_task(self):

def test_revoke_task(self):
channel_id = uuid.uuid4()
async_result = test_task.enqueue(self.user, channel_id=channel_id)
test_task.revoke(channel_id=channel_id)
async_result = test_task_delayed.enqueue(self.user, channel_id=channel_id)
# A bit delay to let the task object be saved async,
# This delay is relatively small and hopefully wont cause any issues in the real time
time.sleep(0.5)
test_task_delayed.revoke(channel_id=channel_id)

# this should raise an exception, even though revoked, because the task is in ready state but not success
with self.assertRaises(Exception):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from django.urls import reverse_lazy
from django_celery_results.models import TaskResult
from le_utils.constants import content_kinds
from le_utils.constants import file_formats
from le_utils.constants import format_presets
Expand All @@ -19,7 +20,6 @@
from contentcuration.constants import user_history
from contentcuration.models import ContentNode
from contentcuration.models import File
from contentcuration.models import TaskResult
from contentcuration.models import UserHistory
from contentcuration.tests.base import BaseAPITestCase
from contentcuration.tests.base import StudioTestCase
Expand Down Expand Up @@ -384,10 +384,9 @@ def test_clean_up(self):
class CleanupTaskTestCase(StudioTestCase):

def setUp(self):
user = self.admin_user
self.pruned_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.SUCCESS, task_name="pruned_task", user_id=user.id)
self.failed_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.FAILURE, task_name="failed_task", user_id=user.id)
self.recent_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.SUCCESS, task_name="recent_task", user_id=user.id)
self.pruned_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.SUCCESS, task_name="pruned_task")
self.failed_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.FAILURE, task_name="failed_task")
self.recent_task = TaskResult.objects.create(task_id=uuid.uuid4().hex, status=states.SUCCESS, task_name="recent_task")

# `date_done` uses `auto_now`, so manually set it
done = datetime.now() - timedelta(days=8)
Expand Down
Loading

0 comments on commit aef05d0

Please sign in to comment.