diff --git a/server/tests/test_background_tasks.py b/server/tests/test_background_tasks.py index 1d9b26106..80aaeeff8 100644 --- a/server/tests/test_background_tasks.py +++ b/server/tests/test_background_tasks.py @@ -342,25 +342,42 @@ def multiple(num): def test_task_interrupted(caplog, db_session): - results = [] + db_session.execute( + """ + CREATE TABLE IF NOT EXISTS task_to_interrupt_results ( + num INT, + inserted_at TIMESTAMPTZ DEFAULT NOW() + ) + """ + ) + db_session.execute("TRUNCATE TABLE task_to_interrupt_results") @background_task - def interrupted(num): - nonlocal results - results.append(num) + def task_to_interrupt(num): + db_session.execute( + "INSERT INTO task_to_interrupt_results (num) VALUES (:num)", dict(num=num) + ) - task1 = create_background_task(interrupted, dict(num=1), db_session) - create_background_task(interrupted, dict(num=2), db_session) + task1 = create_background_task(task_to_interrupt, dict(num=1), db_session) + create_background_task(task_to_interrupt, dict(num=2), db_session) # Simulate that the worker got interrupted mid-task claim_next_task("test_worker", db_session) db_session.commit() + # Simulate starting the task before interruption + task_to_interrupt(num=1) reset_task(task1, db_session) db_session.commit() run_task(claim_next_task("test_worker", db_session), db_session) run_task(claim_next_task("test_worker", db_session), db_session) + results = [ + num + for num, in db_session.execute( + "SELECT num FROM task_to_interrupt_results ORDER BY inserted_at" + ).fetchall() + ] assert results == [1, 2] assert find_log( @@ -368,7 +385,7 @@ def interrupted(num): logging.INFO, ( f"TASK_RESET {{'id': '{task1.id}', " - "'task_name': 'interrupted'," + "'task_name': 'task_to_interrupt'," f" 'payload': {{'num': 1}}," " 'worker_id': 'test_worker'}" ), @@ -378,7 +395,7 @@ def interrupted(num): logging.INFO, ( f"TASK_START {{'id': '{task1.id}', " - "'task_name': 'interrupted'," + "'task_name': 'task_to_interrupt'," f" 'payload': {{'num': 1}}," " 'worker_id': 'test_worker'}" ), @@ -388,7 +405,7 @@ def interrupted(num): logging.INFO, ( f"TASK_COMPLETE {{'id': '{task1.id}', " - "'task_name': 'interrupted'," + "'task_name': 'task_to_interrupt'," f" 'payload': {{'num': 1}}," " 'worker_id': 'test_worker'}" ), @@ -400,14 +417,17 @@ def test_multiple_workers(db_session): num_tasks = 40 num_workers = 4 expected_results = list(range(num_tasks)) - manager = context.Manager() - results = manager.list() + + db_session.execute("CREATE TABLE IF NOT EXISTS count_results (num INT)") + db_session.execute("TRUNCATE TABLE count_results") + db_session.commit() @background_task - def count(num: int): - nonlocal results + def count(db_session, num: int): time.sleep(random.randint(0, 2) / 10) - results.append(num) + db_session.execute( + "INSERT INTO count_results (num) VALUES (:num)", dict(num=num) + ) # Enqueue tasks for num in expected_results: @@ -454,7 +474,8 @@ def num_completed_tasks(): worker.terminate() expected_sorted_results = list(range(num_tasks)) + results = [ + num for num, in db_session.execute("SELECT num FROM count_results").fetchall() + ] # Each task should have run exactly once assert sorted(results) == expected_sorted_results - # Tasks should not have run in order, since some got interrupted and reset - assert results != expected_sorted_results diff --git a/server/worker/tasks.py b/server/worker/tasks.py index 306a38afd..b392d3538 100644 --- a/server/worker/tasks.py +++ b/server/worker/tasks.py @@ -91,9 +91,13 @@ def run_task(task: BackgroundTask, db_session): logger.info(f"TASK_START {task_log_data(task)}") task_args = dict(task.payload) + task_parameters = signature(task_handler).parameters # Inject emit_progress for handlers that want to record task progress - if "emit_progress" in signature(task_handler).parameters: + if "emit_progress" in task_parameters: task_args["emit_progress"] = emit_progress_for_task(task.id) + # For testing, allow the db_session to be injected into the task handler. + if "db_session" in task_parameters: + task_args["db_session"] = db_session try: task_handler(**task_args) @@ -149,6 +153,7 @@ def claim_next_task(worker_id: str, db_session) -> Optional[BackgroundTask]: def reset_task(task: BackgroundTask, db_session): + db_session.rollback() logger.info(f"TASK_RESET {task_log_data(task)}") task.worker_id = None task.started_at = None