Skip to content

Commit

Permalink
Rollback db transaction when background task is interrupted (#2019)
Browse files Browse the repository at this point in the history
Previously, when a task got interrupted, we let the process die
immediately, which would roll back the transaction automatically.

Now, we try to reset the task, which involves clearing the task's
started_at timestamp and then committing the current transaction to save
that change to the db. So in order to ensure that any in-progress
changes from the task itself are not committed alongside those changes,
we need to roll back the current transaction and start a new one.

I updated the tests for task resetting to cover this behavior, which
they previously didn't cover because they didn't use the db.
  • Loading branch information
jonahkagan authored Oct 29, 2024
1 parent f48b678 commit 4615ef9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 17 deletions.
53 changes: 37 additions & 16 deletions server/tests/test_background_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,33 +342,50 @@ def multiple(num):


def test_task_interrupted(caplog, db_session):
results = []
db_session.execute(
"""
CREATE TABLE IF NOT EXISTS task_to_interrupt_results (
num INT,
inserted_at TIMESTAMPTZ DEFAULT NOW()
)
"""
)
db_session.execute("TRUNCATE TABLE task_to_interrupt_results")

@background_task
def interrupted(num):
nonlocal results
results.append(num)
def task_to_interrupt(num):
db_session.execute(
"INSERT INTO task_to_interrupt_results (num) VALUES (:num)", dict(num=num)
)

task1 = create_background_task(interrupted, dict(num=1), db_session)
create_background_task(interrupted, dict(num=2), db_session)
task1 = create_background_task(task_to_interrupt, dict(num=1), db_session)
create_background_task(task_to_interrupt, dict(num=2), db_session)

# Simulate that the worker got interrupted mid-task
claim_next_task("test_worker", db_session)
db_session.commit()
# Simulate starting the task before interruption
task_to_interrupt(num=1)
reset_task(task1, db_session)
db_session.commit()

run_task(claim_next_task("test_worker", db_session), db_session)
run_task(claim_next_task("test_worker", db_session), db_session)

results = [
num
for num, in db_session.execute(
"SELECT num FROM task_to_interrupt_results ORDER BY inserted_at"
).fetchall()
]
assert results == [1, 2]

assert find_log(
caplog,
logging.INFO,
(
f"TASK_RESET {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -378,7 +395,7 @@ def interrupted(num):
logging.INFO,
(
f"TASK_START {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -388,7 +405,7 @@ def interrupted(num):
logging.INFO,
(
f"TASK_COMPLETE {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -400,14 +417,17 @@ def test_multiple_workers(db_session):
num_tasks = 40
num_workers = 4
expected_results = list(range(num_tasks))
manager = context.Manager()
results = manager.list()

db_session.execute("CREATE TABLE IF NOT EXISTS count_results (num INT)")
db_session.execute("TRUNCATE TABLE count_results")
db_session.commit()

@background_task
def count(num: int):
nonlocal results
def count(db_session, num: int):
time.sleep(random.randint(0, 2) / 10)
results.append(num)
db_session.execute(
"INSERT INTO count_results (num) VALUES (:num)", dict(num=num)
)

# Enqueue tasks
for num in expected_results:
Expand Down Expand Up @@ -454,7 +474,8 @@ def num_completed_tasks():
worker.terminate()

expected_sorted_results = list(range(num_tasks))
results = [
num for num, in db_session.execute("SELECT num FROM count_results").fetchall()
]
# Each task should have run exactly once
assert sorted(results) == expected_sorted_results
# Tasks should not have run in order, since some got interrupted and reset
assert results != expected_sorted_results
7 changes: 6 additions & 1 deletion server/worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def run_task(task: BackgroundTask, db_session):
logger.info(f"TASK_START {task_log_data(task)}")

task_args = dict(task.payload)
task_parameters = signature(task_handler).parameters
# Inject emit_progress for handlers that want to record task progress
if "emit_progress" in signature(task_handler).parameters:
if "emit_progress" in task_parameters:
task_args["emit_progress"] = emit_progress_for_task(task.id)
# For testing, allow the db_session to be injected into the task handler.
if "db_session" in task_parameters:
task_args["db_session"] = db_session

try:
task_handler(**task_args)
Expand Down Expand Up @@ -149,6 +153,7 @@ def claim_next_task(worker_id: str, db_session) -> Optional[BackgroundTask]:


def reset_task(task: BackgroundTask, db_session):
db_session.rollback()
logger.info(f"TASK_RESET {task_log_data(task)}")
task.worker_id = None
task.started_at = None
Expand Down

0 comments on commit 4615ef9

Please sign in to comment.