Skip to content

Commit

Permalink
Abort jobs on deletion or expiration
Browse files Browse the repository at this point in the history
Refactor expiration of jobs to log each job being expired and to
abort the job if it is running. Also abort jobs if they are running
when the job is deleted, and add a test for deletion of jobs via the
POST method rather than the DELETE method.
  • Loading branch information
rra committed Jul 15, 2024
1 parent 5621225 commit 2fc9038
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 26 deletions.
3 changes: 3 additions & 0 deletions changelog.d/20240715_164613_rra_DM_45138.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### New features

- Abort jobs on deletion or expiration if they are pending, queued, or executing.
3 changes: 3 additions & 0 deletions src/vocutouts/uws/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class UWSJobDescription:
job_id: str
"""Unique identifier of the job."""

message_id: str | None
"""Internal message identifier for the work queuing system."""

owner: str
"""Identity of the owner of the job."""

Expand Down
41 changes: 36 additions & 5 deletions src/vocutouts/uws/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,47 @@ async def create(
async def delete(self, user: str, job_id: str) -> None:
"""Delete a job.
The UWS standard says that deleting a job should stop the in-progress
work, but arq, although it supports job cancellation, cannot cancel
sync jobs. Settle for deleting the database entry, which will cause
the task to throw away the results when it finishes.
If the job is in an active phase, cancel it before deleting it.
Parameters
----------
user
Owner of job.
job_id
Identifier of job.
"""
job = await self._storage.get(job_id)
if job.owner != user:
raise PermissionDeniedError(f"Access to job {job_id} denied")
logger = self._logger.bind(user=user, job_id=job_id)
if job.phase in ACTIVE_PHASES and job.message_id:
try:
await self._arq.abort_job(job.message_id)
except Exception as e:
logger.warning("Unable to abort job", error=str(e))
await self._storage.delete(job_id)
self._logger.info("Deleted job", user=user, job_id=job_id)
logger.info("Deleted job")

async def delete_expired(self) -> None:
"""Delete all expired jobs.
A job is expired if it has passed its destruction time. If the job is
in an active phase, cancel it before deleting it.
"""
jobs = await self._storage.list_expired()
if jobs:
self._logger.info(f"Deleting {len(jobs)} expired jobs")
for job in jobs:
if job.phase in ACTIVE_PHASES and job.message_id:
try:
await self._arq.abort_job(job.message_id)
except Exception as e:
self._logger.warning(
"Unable to abort expired job", error=str(e)
)
await self._storage.delete(job.job_id)
self._logger.info("Deleted expired job")
self._logger.info(f"Finished deleting {len(jobs)} expired jobs")

async def get(
self,
Expand Down
26 changes: 23 additions & 3 deletions src/vocutouts/uws/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,30 @@ async def get(self, job_id: str) -> UWSJob:
job = await self._get_job(job_id)
return _convert_job(job)

async def delete_expired(self) -> None:
async def list_expired(self) -> list[UWSJobDescription]:
"""Delete all jobs that have passed their destruction time."""
now = datetime_to_db(current_datetime())
stmt = delete(SQLJob).where(SQLJob.destruction_time <= now)
stmt = select(
SQLJob.id,
SQLJob.message_id,
SQLJob.owner,
SQLJob.phase,
SQLJob.run_id,
SQLJob.creation_time,
).where(SQLJob.destruction_time <= now)
async with self._session.begin():
await self._session.execute(stmt)
jobs = await self._session.execute(stmt)
return [
UWSJobDescription(
job_id=str(j.id),
message_id=j.message_id,
owner=j.owner,
phase=j.phase,
run_id=j.run_id,
creation_time=datetime_from_db(j.creation_time),
)
for j in jobs.all()
]

async def list_jobs(
self,
Expand Down Expand Up @@ -246,6 +264,7 @@ async def list_jobs(
"""
stmt = select(
SQLJob.id,
SQLJob.message_id,
SQLJob.owner,
SQLJob.phase,
SQLJob.run_id,
Expand All @@ -263,6 +282,7 @@ async def list_jobs(
return [
UWSJobDescription(
job_id=str(j.id),
message_id=j.message_id,
owner=j.owner,
phase=j.phase,
run_id=j.run_id,
Expand Down
30 changes: 17 additions & 13 deletions src/vocutouts/uws/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .constants import JOB_RESULT_TIMEOUT
from .exceptions import TaskError, UnknownJobError
from .models import UWSJob
from .service import JobService
from .storage import JobStore
from .uwsworker import WorkerError, WorkerTransientError

Expand Down Expand Up @@ -64,13 +65,26 @@ async def create_uws_worker_context(
Keys to add to the ``ctx`` dictionary.
"""
logger = logger.bind(worker_instance=uuid.uuid4().hex)

# The queue from which to retrieve results is the main work queue,
# which uses the default arq queue name. Note that this is not the
# separate UWS queue this worker is running against.
if config.arq_mode == ArqMode.production:
settings = config.arq_redis_settings
arq: ArqQueue = await RedisArqQueue.initialize(settings)
else:
arq = MockArqQueue()

engine = create_database_engine(
config.database_url,
config.database_password,
isolation_level="REPEATABLE READ",
)
session = await create_async_session(engine, logger)
storage = JobStore(session)
service = JobService(
config=config, arq_queue=arq, storage=storage, logger=logger
)
slack = None
if config.slack_webhook:
slack = SlackWebhookClient(
Expand All @@ -79,19 +93,11 @@ async def create_uws_worker_context(
logger,
)

# The queue from which to retrieve results is the main work queue,
# which uses the default arq queue name. Note that this is not the
# separate UWS queue this worker is running against.
if config.arq_mode == ArqMode.production:
settings = config.arq_redis_settings
arq: ArqQueue = await RedisArqQueue.initialize(settings)
else:
arq = MockArqQueue()

logger.info("Worker startup complete")
return {
"arq": arq,
"logger": logger,
"service": service,
"session": session,
"slack": slack,
"storage": storage,
Expand Down Expand Up @@ -128,17 +134,15 @@ async def uws_expire_jobs(ctx: dict[Any, Any]) -> None:
ctx
arq context.
"""
logger: BoundLogger = ctx["logger"].bind(task="expire_jobs")
slack: SlackWebhookClient | None = ctx["slack"]
storage: JobStore = ctx["storage"]
service: JobService = ctx["service"]

try:
await storage.delete_expired()
await service.delete_expired()
except Exception as e:
if slack:
await slack.post_uncaught_exception(e)
raise
logger.info("Deleted expired jobs")


async def uws_job_started(
Expand Down
19 changes: 18 additions & 1 deletion tests/support/uws.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from arq.connections import RedisSettings
from fastapi import Form, Query
from pydantic import BaseModel, SecretStr
from safir.arq import ArqMode, JobMetadata, MockArqQueue
from safir.arq import ArqMode, JobMetadata, JobResult, MockArqQueue

from vocutouts.uws.config import ParametersModel, UWSConfig, UWSRoute
from vocutouts.uws.dependencies import UWSFactory
Expand Down Expand Up @@ -135,6 +135,23 @@ async def get_job_metadata(
assert job.message_id
return await self._arq.get_job_metadata(job.message_id)

async def get_job_result(self, username: str, job_id: str) -> JobResult:
"""Get the arq job result for a job.
Parameters
----------
job_id
UWS job ID.
Returns
-------
JobMetadata
arq job metadata.
"""
job = await self._service.get(username, job_id)
assert job.message_id
return await self._arq.get_job_result(job.message_id)

async def mark_in_progress(
self, username: str, job_id: str, *, delay: float | None = None
) -> UWSJob:
Expand Down
32 changes: 28 additions & 4 deletions tests/uws/job_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

import asyncio
from datetime import UTC, datetime, timedelta
from unittest.mock import ANY

Expand Down Expand Up @@ -276,6 +277,7 @@ async def test_job_run(
async def test_job_abort(
client: AsyncClient,
runner: MockJobRunner,
arq_queue: MockArqQueue,
uws_factory: UWSFactory,
uws_config: UWSConfig,
) -> None:
Expand Down Expand Up @@ -334,6 +336,32 @@ async def test_job_abort(
isodatetime(job.end_time),
isodatetime(job.creation_time + timedelta(seconds=24 * 60 * 60)),
)
job_result = await runner.get_job_result("user", "2")
assert not job_result.success
assert isinstance(job_result.result, asyncio.CancelledError)

# Deleting a job should also abort it. Also test a weird capitalization of
# the phase parameter and the POST form of the delete support.
r = await client.post(
"/test/jobs",
headers={"X-Auth-Request-User": "user"},
data={"runid": "some-run-id", "name": "Jane", "PHAse": "RUN"},
)
assert r.status_code == 303
assert r.headers["Location"] == "https://example.com/test/jobs/3"
await runner.mark_in_progress("user", "3")
job = await job_service.get("user", "3")
r = await client.post(
"/test/jobs/3",
headers={"X-Auth-Request-User": "user"},
data={"action": "DELETE"},
)
assert r.status_code == 303
assert r.headers["Location"] == "https://example.com/test/jobs"
assert job.message_id
job_result = await arq_queue.get_job_result(job.message_id)
assert not job_result.success
assert isinstance(job_result.result, asyncio.CancelledError)


@pytest.mark.asyncio
Expand Down Expand Up @@ -561,16 +589,12 @@ async def test_presigned_url(
uws_factory: UWSFactory,
uws_config: UWSConfig,
) -> None:
job_service = uws_factory.create_job_service()

# Create the job.
r = await client.post(
"/test/jobs?phase=RUN",
headers={"X-Auth-Request-User": "user"},
data={"runid": "some-run-id", "name": "Jane"},
)
assert r.status_code == 303
job = await job_service.get("user", "1")
await runner.mark_in_progress("user", "1")

# Tell the queue the job is finished, with an https URL.
Expand Down

0 comments on commit 2fc9038

Please sign in to comment.