Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollback db transaction when background task is interrupted #2019

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions server/tests/test_background_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,33 +342,50 @@ def multiple(num):


def test_task_interrupted(caplog, db_session):
results = []
db_session.execute(
"""
CREATE TABLE IF NOT EXISTS task_to_interrupt_results (
num INT,
inserted_at TIMESTAMPTZ DEFAULT NOW()
)
"""
)
db_session.execute("TRUNCATE TABLE task_to_interrupt_results")

@background_task
def interrupted(num):
nonlocal results
results.append(num)
def task_to_interrupt(num):
db_session.execute(
"INSERT INTO task_to_interrupt_results (num) VALUES (:num)", dict(num=num)
)

task1 = create_background_task(interrupted, dict(num=1), db_session)
create_background_task(interrupted, dict(num=2), db_session)
task1 = create_background_task(task_to_interrupt, dict(num=1), db_session)
create_background_task(task_to_interrupt, dict(num=2), db_session)

# Simulate that the worker got interrupted mid-task
claim_next_task("test_worker", db_session)
db_session.commit()
# Simulate starting the task before interruption
task_to_interrupt(num=1)
reset_task(task1, db_session)
db_session.commit()

run_task(claim_next_task("test_worker", db_session), db_session)
run_task(claim_next_task("test_worker", db_session), db_session)

results = [
num
for num, in db_session.execute(
"SELECT num FROM task_to_interrupt_results ORDER BY inserted_at"
).fetchall()
]
assert results == [1, 2]

assert find_log(
caplog,
logging.INFO,
(
f"TASK_RESET {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -378,7 +395,7 @@ def interrupted(num):
logging.INFO,
(
f"TASK_START {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -388,7 +405,7 @@ def interrupted(num):
logging.INFO,
(
f"TASK_COMPLETE {{'id': '{task1.id}', "
"'task_name': 'interrupted',"
"'task_name': 'task_to_interrupt',"
f" 'payload': {{'num': 1}},"
" 'worker_id': 'test_worker'}"
),
Expand All @@ -400,14 +417,17 @@ def test_multiple_workers(db_session):
num_tasks = 40
num_workers = 4
expected_results = list(range(num_tasks))
manager = context.Manager()
results = manager.list()

db_session.execute("CREATE TABLE IF NOT EXISTS count_results (num INT)")
db_session.execute("TRUNCATE TABLE count_results")
db_session.commit()

@background_task
def count(num: int):
nonlocal results
def count(db_session, num: int):
time.sleep(random.randint(0, 2) / 10)
results.append(num)
db_session.execute(
"INSERT INTO count_results (num) VALUES (:num)", dict(num=num)
)

# Enqueue tasks
for num in expected_results:
Expand Down Expand Up @@ -454,7 +474,8 @@ def num_completed_tasks():
worker.terminate()

expected_sorted_results = list(range(num_tasks))
results = [
num for num, in db_session.execute("SELECT num FROM count_results").fetchall()
]
# Each task should have run exactly once
assert sorted(results) == expected_sorted_results
# Tasks should not have run in order, since some got interrupted and reset
assert results != expected_sorted_results
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I like what this assertion is trying to test, it relies on non-deterministic timing for the tasks to get executed out of order. I couldn't think of an easy way to make this work deterministically, so to prevent flaky tests I'm removing it.

7 changes: 6 additions & 1 deletion server/worker/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def run_task(task: BackgroundTask, db_session):
logger.info(f"TASK_START {task_log_data(task)}")

task_args = dict(task.payload)
task_parameters = signature(task_handler).parameters
# Inject emit_progress for handlers that want to record task progress
if "emit_progress" in signature(task_handler).parameters:
if "emit_progress" in task_parameters:
task_args["emit_progress"] = emit_progress_for_task(task.id)
# For testing, allow the db_session to be injected into the task handler.
if "db_session" in task_parameters:
task_args["db_session"] = db_session

try:
task_handler(**task_args)
Expand Down Expand Up @@ -149,6 +153,7 @@ def claim_next_task(worker_id: str, db_session) -> Optional[BackgroundTask]:


def reset_task(task: BackgroundTask, db_session):
db_session.rollback()
logger.info(f"TASK_RESET {task_log_data(task)}")
task.worker_id = None
task.started_at = None
Expand Down