From a5f7619f72ea7c01d12642b83eb1b1463a63099e Mon Sep 17 00:00:00 2001
From: Yann Normand
Date: Mon, 26 Aug 2024 00:15:19 +1000
Subject: [PATCH] add support for abort notification

---
 docs/discussions.md                             | 11 ++-
 docs/howto/advanced/cancellation.md             | 29 ++-----
 procrastinate/app.py                            | 31 ++++---
 procrastinate/connector.py                      | 13 ++-
 .../contrib/aiopg/aiopg_connector.py            | 19 ++--
 .../contrib/django/django_connector.py          |  3 +-
 .../migrations/0032_cancel_notification.py      | 15 ++++
 procrastinate/contrib/django/models.py          |  1 +
 procrastinate/job_context.py                    | 14 +--
 procrastinate/jobs.py                           | 20 ++++-
 procrastinate/manager.py                        | 84 ++++++++----------
 procrastinate/psycopg_connector.py              | 17 ++--
 .../03.00.00_01_cancel_notification.sql         | 87 +++++++++++++++++++
 procrastinate/sql/queries.sql                   | 14 +--
 procrastinate/sql/schema.sql                    | 64 +++++++++++---
 procrastinate/testing.py                        | 56 +++++++++---
 procrastinate/utils.py                          |  2 +-
 procrastinate/worker.py                         | 68 ++++++++++++---
 tests/acceptance/test_async.py                  | 73 +++++++++++-----
 tests/conftest.py                               |  4 +-
 tests/integration/contrib/aiopg/conftest.py     |  2 +-
 .../contrib/aiopg/test_aiopg_connector.py       | 42 +++++++--
 tests/integration/contrib/django/test_models.py |  1 +
 tests/integration/test_psycopg_connector.py     | 36 ++++++--
 tests/unit/test_app.py                          |  2 +-
 tests/unit/test_builtin_tasks.py                |  2 +-
 tests/unit/test_connector.py                    |  2 +-
 tests/unit/test_job_context.py                  | 27 ++----
 tests/unit/test_jobs.py                         |  1 +
 tests/unit/test_manager.py                      |  8 +-
 tests/unit/test_testing.py                      |  9 +-
 tests/unit/test_worker.py                       | 43 ++++++---
 tests/unit/test_worker_sync.py                  |  7 +-
 33 files changed, 565 insertions(+), 242 deletions(-)
 create mode 100644 procrastinate/contrib/django/migrations/0032_cancel_notification.py
 create mode 100644 procrastinate/sql/migrations/03.00.00_01_cancel_notification.sql

diff --git a/docs/discussions.md b/docs/discussions.md
index b0310c874..6758e5a55 100644
--- a/docs/discussions.md
+++ b/docs/discussions.md
@@ -201,14 +201,21 @@ many factors to take into account when [sizing your pool](https://wiki.postgresq
 
 ### How the `polling_interval` works
 
-Even when the database doesn't notify workers regarding newly deferred jobs, idle
-workers still poll the database every now and then, just in case.
+Even when the database doesn't notify workers regarding newly deferred jobs, each
+worker still polls the database every now and then, just in case.
 There could be previously locked jobs that are now free, or scheduled jobs that have
 reached the ETA. `polling_interval` is the {py:meth}`App.run_worker` parameter (or the
 equivalent CLI flag) that sizes this "every now and then".
 
 A worker will keep fetching new jobs as long as they have capacity to process them.
 The polling interval starts from the moment the last attempt to fetch a new job
 yields no result.
+
+The `polling_interval` also defines how often the worker will poll the database for jobs to abort.
+When `listen_notify=True`, the worker will likely be notified "instantly" of each abort request, prior to any polling.
+
+However, when `listen_notify=False`, or if an abort notification was missed, `polling_interval` represents the maximum delay before the worker reacts to an abort request.
+
+Note that the worker will not poll the database for jobs to abort while it is idle (i.e. when it has no running jobs).
+
 :::{note}
 The polling interval was previously called `timeout` in pre-v3 versions of
 Procrastinate. It was renamed to `polling_interval` for clarity.
 :::
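To make the interaction between `polling_interval` and `listen_notify` concrete, here is a minimal sketch of starting a worker with both settings spelled out; the connector wiring and queue name are placeholders for illustration, not part of this patch:

```python
import asyncio

from procrastinate import App, PsycopgConnector

# Placeholder wiring; a real project points the connector at its own database.
app = App(connector=PsycopgConnector())


async def main():
    # With listen_notify=True (the default), LISTEN/NOTIFY delivers new-job and
    # abort notifications almost immediately, so polling_interval only acts as
    # a fallback and can reasonably be set higher.
    await app.run_worker_async(
        queues=["default"],
        polling_interval=5.0,
        listen_notify=True,
    )


asyncio.run(main())
```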
diff --git a/docs/howto/advanced/cancellation.md b/docs/howto/advanced/cancellation.md
index 0743c21e8..d5d924bb7 100644
--- a/docs/howto/advanced/cancellation.md
+++ b/docs/howto/advanced/cancellation.md
@@ -24,11 +24,10 @@ app.job_manager.cancel_job_by_id(33, delete_job=True)
 await app.job_manager.cancel_job_by_id_async(33, delete_job=True)
 ```
 
-## Mark a currently being processed job for abortion
+## Mark a running job for abortion
 
 If a worker has not picked up the job yet, the below command behaves like the
-command without the `abort` option. But if a job is already in the middle of
-being processed, the `abort` option marks this job for abortion (see below
+command without the `abort` option. But if a job is already running, the `abort`
+option marks this job for abortion (see below
 how to handle this request).
 
 ```python
@@ -38,10 +37,10 @@ app.job_manager.cancel_job_by_id(33, abort=True)
 await app.job_manager.cancel_job_by_id_async(33, abort=True)
 ```
 
-## Handle a abortion request inside the task
+## Handle an abortion request inside the task
 
 In our task, we can check (for example, periodically) if the task should be
-aborted. If we want to respect that request (we don't have to), we raise a
+aborted. If we want to respect that abortion request (we don't have to), we raise a
 `JobAborted` error. Any message passed to `JobAborted` (e.g. `raise
 JobAborted("custom message")`) will end up in the logs.
 
@@ -54,24 +53,10 @@ def my_task(context):
         do_something_expensive()
 ```
 
-There is also an async API
+Behind the scenes, the worker receives a Postgres notification every time a job is
+requested to abort (unless `listen_notify=False`).
 
-```python
-@app.task(pass_context=True)
-async def my_task(context):
-    for i in range(100):
-        if await context.should_abort_async():
-            raise exceptions.JobAborted
-        do_something_expensive()
-```
-
-:::{warning}
-`context.should_abort()` and `context.should_abort_async()` does poll the
-database and might flood the database. Ensure you do it only sometimes and
-not from too many parallel tasks.
-:::
+The worker also polls the database (respecting `polling_interval`) for abortion
+requests, as long as it is running at least one job (in the absence of running
+jobs, there is nothing to abort).
 
 :::{note}
-When a task of a job that was requested to be aborted raises an error, the job
-is marked as failed (regardless of the retry strategy).
+When a job is requested to abort and that job fails, it will not be retried
+(regardless of the retry strategy).
 :::
diff --git a/procrastinate/app.py b/procrastinate/app.py
index 711006574..3eaaeff73 100644
--- a/procrastinate/app.py
+++ b/procrastinate/app.py
@@ -270,22 +270,33 @@ async def run_worker_async(self, **kwargs: Unpack[WorkerOptions]) -> None:
             Name of the worker. Will be passed in the `JobContext` and used in the
             logs (defaults to ``None`` which will result in the worker named
             ``worker``).
-        polling_interval: ``float``
-            Indicates the maximum duration (in seconds) the worker waits between
-            each database job poll. Raising this parameter can lower the rate at which
-            the worker makes queries to the database for requesting jobs.
+        polling_interval : ``float``
+            Maximum time (in seconds) between database job polls.
+
+            Controls the frequency of database queries for:
+            - Checking for new jobs to start
+            - Fetching updates for running jobs
+            - Checking for abort requests
+
+            When `listen_notify` is True, the polling interval acts as a fallback
+            mechanism and can reasonably be set to a higher value.
+ (defaults to 5.0) shutdown_timeout: ``float`` Indicates the maximum duration (in seconds) the worker waits for jobs to complete when requested stop. Jobs that have not been completed by that time are aborted. A value of None corresponds to no timeout. (defaults to None) - listen_notify: ``bool`` - If ``True``, the worker will dedicate a connection from the pool to - listening to database events, notifying of newly available jobs. - If ``False``, the worker will just poll the database periodically - (see ``polling_interval``). (defaults to ``True``) - delete_jobs: ``str`` + listen_notify : ``bool`` + If ``True``, allocates a connection from the pool to + listen for: + - new job availability + - job abort requests + + Provides lower latency for job updates compared to polling alone. + + Note: Worker polls the database regardless of this setting. (defaults to ``True``) + delete_jobs : ``str`` If ``always``, the worker will automatically delete all jobs on completion. If ``successful`` the worker will only delete successful jobs. If ``never``, the worker will keep the jobs in the database. diff --git a/procrastinate/connector.py b/procrastinate/connector.py index 50615d2b4..541a772a9 100644 --- a/procrastinate/connector.py +++ b/procrastinate/connector.py @@ -1,7 +1,6 @@ from __future__ import annotations -import asyncio -from typing import Any, Callable, Iterable +from typing import Any, Awaitable, Callable, Iterable, Protocol from typing_extensions import LiteralString @@ -13,6 +12,10 @@ LISTEN_TIMEOUT = 30.0 +class Notify(Protocol): + def __call__(self, *, channel: str, payload: str) -> Awaitable[None]: ... + + class BaseConnector: json_dumps: Callable | None = None json_loads: Callable | None = None @@ -59,7 +62,9 @@ async def execute_query_all_async( raise exceptions.SyncConnectorConfigurationError async def listen_notify( - self, event: asyncio.Event, channels: Iterable[str] + self, + on_notification: Notify, + channels: Iterable[str], ) -> None: raise exceptions.SyncConnectorConfigurationError @@ -98,6 +103,6 @@ def execute_query_all( return utils.async_to_sync(self.execute_query_all_async, query, **arguments) async def listen_notify( - self, event: asyncio.Event, channels: Iterable[str] + self, on_notification: Notify, channels: Iterable[str] ) -> None: raise NotImplementedError diff --git a/procrastinate/contrib/aiopg/aiopg_connector.py b/procrastinate/contrib/aiopg/aiopg_connector.py index c3ef93976..962ff46fa 100644 --- a/procrastinate/contrib/aiopg/aiopg_connector.py +++ b/procrastinate/contrib/aiopg/aiopg_connector.py @@ -283,7 +283,7 @@ def _make_dynamic_query(self, query: str, **identifiers: str) -> Any: @wrap_exceptions() async def listen_notify( - self, event: asyncio.Event, channels: Iterable[str] + self, on_notification: connector.Notify, channels: Iterable[str] ) -> None: # We need to acquire a dedicated connection, and use the listen # query @@ -304,14 +304,14 @@ async def listen_notify( query=sql.queries["listen_queue"], channel_name=channel_name ), ) - # Initial set() lets caller know that we're ready to listen - event.set() - await self._loop_notify(event=event, connection=connection) + await self._loop_notify( + on_notification=on_notification, connection=connection + ) @wrap_exceptions() async def _loop_notify( self, - event: asyncio.Event, + on_notification: connector.Notify, connection: aiopg.Connection, timeout: float = connector.LISTEN_TIMEOUT, ) -> None: @@ -324,12 +324,15 @@ async def _loop_notify( if connection.closed: return try: - await 
asyncio.wait_for(connection.notifies.get(), timeout) + notification = await asyncio.wait_for( + connection.notifies.get(), timeout + ) + await on_notification( + channel=notification.channel, payload=notification.payload + ) except asyncio.TimeoutError: continue except psycopg2.Error: # aiopg>=1.3.1 will raise if the connection is closed while # we wait continue - - event.set() diff --git a/procrastinate/contrib/django/django_connector.py b/procrastinate/contrib/django/django_connector.py index 909c7ac8b..8f178c379 100644 --- a/procrastinate/contrib/django/django_connector.py +++ b/procrastinate/contrib/django/django_connector.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import contextlib from typing import ( TYPE_CHECKING, @@ -141,7 +140,7 @@ def execute_query_all( return list(self._dictfetch(cursor)) async def listen_notify( - self, event: asyncio.Event, channels: Iterable[str] + self, on_notification: connector.Notify, channels: Iterable[str] ) -> None: raise NotImplementedError( "listen/notify is not supported with Django connector" diff --git a/procrastinate/contrib/django/migrations/0032_cancel_notification.py b/procrastinate/contrib/django/migrations/0032_cancel_notification.py new file mode 100644 index 000000000..617265857 --- /dev/null +++ b/procrastinate/contrib/django/migrations/0032_cancel_notification.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from django.db import migrations + +from .. import migrations_utils + + +class Migration(migrations.Migration): + operations = [ + migrations_utils.RunProcrastinateSQL( + name="03.00.00_01_cancel_notification.sql" + ), + ] + name = "0032_cancel_notification" + dependencies = [("procrastinate", "0031_add_abort_on_procrastinate_jobs")] diff --git a/procrastinate/contrib/django/models.py b/procrastinate/contrib/django/models.py index 9ae4bd10c..64ce3c319 100644 --- a/procrastinate/contrib/django/models.py +++ b/procrastinate/contrib/django/models.py @@ -99,6 +99,7 @@ def procrastinate_job(self) -> jobs.Job: status=self.status, scheduled_at=self.scheduled_at, attempts=self.attempts, + abort_requested=self.abort_requested, queueing_lock=self.queueing_lock, ) diff --git a/procrastinate/job_context.py b/procrastinate/job_context.py index 38abc4ec3..d03a10304 100644 --- a/procrastinate/job_context.py +++ b/procrastinate/job_context.py @@ -1,7 +1,7 @@ from __future__ import annotations import time -from typing import Any, Iterable +from typing import Any, Callable, Iterable import attr @@ -54,6 +54,8 @@ class JobContext: additional_context: dict = attr.ib(factory=dict) task_result: Any = None + should_abort: Callable[[], bool] + def evolve(self, **update: Any) -> JobContext: return attr.evolve(self, **update) @@ -68,13 +70,3 @@ def job_description(self, current_timestamp: float) -> str: message += f" (started {duration:.3f} s ago)" return message - - def should_abort(self) -> bool: - assert self.job.id - job_id = self.job.id - return self.app.job_manager.get_job_abort_requested(job_id) - - async def should_abort_async(self) -> bool: - assert self.job.id - job_id = self.job.id - return await self.app.job_manager.get_job_abort_requested_async(job_id) diff --git a/procrastinate/jobs.py b/procrastinate/jobs.py index 1127c370f..5c16d67ac 100644 --- a/procrastinate/jobs.py +++ b/procrastinate/jobs.py @@ -4,9 +4,10 @@ import functools import logging from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict, Union import attr +from typing_extensions import 
Literal

from procrastinate import types

@@ -22,6 +23,19 @@
 cached_property = getattr(functools, "cached_property", property)
 
 
+class JobInserted(TypedDict):
+    type: Literal["job_inserted"]
+    job_id: int
+
+
+class AbortJobRequested(TypedDict):
+    type: Literal["abort_job_requested"]
+    job_id: int
+
+
+Notification = Union[JobInserted, AbortJobRequested]
+
+
 def check_aware(
     instance: Job, attribute: attr.Attribute, value: datetime.datetime
 ) -> None:
@@ -82,6 +96,9 @@ class Job:
     #: Number of times the job has been tried.
     attempts: int = 0
 
+    #: True if the job is requested to abort.
+    abort_requested: bool = False
+
     @classmethod
     def from_row(cls, row: dict[str, Any]) -> Job:
         return cls(
@@ -95,6 +112,7 @@
             scheduled_at=row["scheduled_at"],
             queue=row["queue_name"],
             attempts=row["attempts"],
+            abort_requested=row.get("abort_requested", False),
         )
 
     def asdict(self) -> dict[str, Any]:
diff --git a/procrastinate/manager.py b/procrastinate/manager.py
index 7fc360d0e..a80c3bf75 100644
--- a/procrastinate/manager.py
+++ b/procrastinate/manager.py
@@ -1,9 +1,9 @@
 from __future__ import annotations
 
-import asyncio
 import datetime
+import json
 import logging
-from typing import Any, Iterable, NoReturn
+from typing import Any, Awaitable, Iterable, NoReturn, Protocol
 
 from procrastinate import connector, exceptions, jobs, sql, utils
 
@@ -12,6 +12,12 @@
 QUEUEING_LOCK_CONSTRAINT = "procrastinate_jobs_queueing_lock_idx"
 
 
+class NotificationCallback(Protocol):
+    def __call__(
+        self, *, channel: str, notification: jobs.Notification
+    ) -> Awaitable[None]: ...
+
+
 def get_channel_for_queues(queues: Iterable[str] | None = None) -> Iterable[str]:
     if queues is None:
         return ["procrastinate_any_queue"]
@@ -360,42 +366,6 @@ async def get_job_status_async(self, job_id: int) -> jobs.Status:
         )
         return jobs.Status(result["status"])
 
-    def get_job_abort_requested(self, job_id: int) -> bool:
-        """
-        Check if a job is requested for abortion.
-
-        Parameters
-        ----------
-        job_id : ``int``
-            The id of the job to get the abortion request of
-
-        Returns
-        -------
-        ``bool``
-        """
-        result = self.connector.get_sync_connector().execute_query_one(
-            query=sql.queries["get_job_abort_requested"], job_id=job_id
-        )
-        return bool(result["abort_requested"])
-
-    async def get_job_abort_requested_async(self, job_id: int) -> bool:
-        """
-        Check if a job is requested for abortion.
-
-        Parameters
-        ----------
-        job_id : ``int``
-            The id of the job to get the abortion request of
-
-        Returns
-        -------
-        ``bool``
-        """
-        result = await self.connector.execute_query_one_async(
-            query=sql.queries["get_job_abort_requested"], job_id=job_id
-        )
-        return bool(result["abort_requested"])
-
     async def retry_job(
         self,
         job: jobs.Job,
@@ -491,26 +461,39 @@
     )
 
     async def listen_for_jobs(
-        self, *, event: asyncio.Event, queues: Iterable[str] | None = None
+        self,
+        *,
+        on_notification: NotificationCallback,
+        queues: Iterable[str] | None = None,
    ) -> None:
        """
-        Listens to defer operation in the database, and raises the event each time an
-        defer operation is seen.
+        Listens to job notifications from the database, and invokes the callback each
+        time a notification is received.
 
         This coroutine either returns ``None`` upon calling if it cannot start
         listening or does not return and needs to be cancelled to end.
 
         Parameters
         ----------
-        event:
-            This event will be set each time a defer operation occurs
-        queues:
-            If ``None``, all defer operations will be considered. If an iterable of
+        on_notification : ``NotificationCallback``
+            A coroutine that will be called and awaited every time a notification is
+            received
+        queues : ``Optional[Iterable[str]]``
+            If ``None``, all notifications will be considered. If an iterable of
             queue names is passed, only defer operations on those queues will be
             considered. Defaults to ``None``
         """
+
+        async def handle_notification(channel: str, payload: str):
+            notification: jobs.Notification = json.loads(payload)
+            logger.debug(
+                f"Received {notification['type']} notification from channel {channel}",
+                extra={"channel": channel, "payload": payload},
+            )
+            await on_notification(channel=channel, notification=notification)
+
         await self.connector.listen_notify(
-            event=event, channels=get_channel_for_queues(queues=queues)
+            on_notification=handle_notification,
+            channels=get_channel_for_queues(queues=queues),
         )
 
     async def check_connection_async(self) -> bool:
@@ -836,3 +819,12 @@
             }
         )
         return result
+
+    async def list_jobs_to_abort_async(self, queue: str | None = None) -> Iterable[int]:
+        """
+        List ids of running jobs to abort.
+        """
+        rows = await self.connector.execute_query_all_async(
+            query=sql.queries["list_jobs_to_abort"], queue_name=queue
+        )
+        return [row["id"] for row in rows]
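For illustration, a callback satisfying the new `NotificationCallback` protocol could look like the sketch below; the `print` calls and the assumption of an `app` opened elsewhere are illustrative only:

```python
from procrastinate import jobs


async def on_notification(*, channel: str, notification: jobs.Notification) -> None:
    # listen_for_jobs has already json.loads()-ed the payload; narrowing on the
    # "type" key distinguishes the two members of the Notification union.
    if notification["type"] == "job_inserted":
        print(f"job {notification['job_id']} deferred on {channel}")
    elif notification["type"] == "abort_job_requested":
        print(f"abort requested for job {notification['job_id']}")


async def listen_forever(app) -> None:
    # Runs until cancelled, which is also how the worker consumes it.
    await app.job_manager.listen_for_jobs(
        on_notification=on_notification, queues=None
    )
```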
diff --git a/procrastinate/psycopg_connector.py b/procrastinate/psycopg_connector.py
index 90a07fd28..f2e0dd2f1 100644
--- a/procrastinate/psycopg_connector.py
+++ b/procrastinate/psycopg_connector.py
@@ -1,6 +1,5 @@
 from __future__ import annotations
 
-import asyncio
 import contextlib
 import logging
 from typing import (
@@ -249,7 +248,7 @@ async def _get_standalone_connection(
 
     @wrap_exceptions()
     async def listen_notify(
-        self, event: asyncio.Event, channels: Iterable[str]
+        self, on_notification: connector.Notify, channels: Iterable[str]
     ) -> None:
         while True:
             async with self._get_standalone_connection() as connection:
@@ -260,14 +259,14 @@
                         channel_name=channel_name,
                     ),
                 )
-                # Initial set() lets caller know that we're ready to listen
-                event.set()
-                await self._loop_notify(event=event, connection=connection)
+                await self._loop_notify(
+                    on_notification=on_notification, connection=connection
+                )
 
     @wrap_exceptions()
     async def _loop_notify(
         self,
-        event: asyncio.Event,
+        on_notification: connector.Notify,
         connection: psycopg.AsyncConnection,
         timeout: float = connector.LISTEN_TIMEOUT,
     ) -> None:
@@ -275,12 +274,14 @@
 
         while True:
             try:
-                async for _ in utils.gen_with_timeout(
+                async for notification in utils.gen_with_timeout(
                     aiterable=connection.notifies(),
                     timeout=timeout,
                     raise_timeout=False,
                 ):
-                    event.set()
+                    await on_notification(
+                        channel=notification.channel, payload=notification.payload
+                    )
 
                 await connection.execute("SELECT 1")
             except psycopg.OperationalError:
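At the connector level the callback is lower level: it receives the raw channel name and the JSON payload string produced by the `pg_notify` triggers introduced in the SQL migration below, and decoding is left to the caller. A sketch against the public connector API, with a hypothetical connection string:

```python
import asyncio
import json

from procrastinate import PsycopgConnector

# Hypothetical connection string, for illustration only.
connector = PsycopgConnector(conninfo="postgresql:///procrastinate")


async def print_notification(*, channel: str, payload: str) -> None:
    # The payload is the raw string passed to pg_notify() by the triggers.
    print(channel, json.loads(payload))


async def main() -> None:
    await connector.open_async()
    try:
        # LISTENs on the given channels and awaits the callback for every
        # NOTIFY received; only returns when cancelled.
        await connector.listen_notify(
            on_notification=print_notification,
            channels=["procrastinate_any_queue"],
        )
    finally:
        await connector.close_async()


asyncio.run(main())
```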
diff --git a/procrastinate/sql/migrations/03.00.00_01_cancel_notification.sql b/procrastinate/sql/migrations/03.00.00_01_cancel_notification.sql
new file mode 100644
index 000000000..c2925b1f7
--- /dev/null
+++ b/procrastinate/sql/migrations/03.00.00_01_cancel_notification.sql
@@ -0,0 +1,87 @@
+CREATE OR REPLACE FUNCTION procrastinate_notify_queue_job_inserted()
+RETURNS trigger
+    LANGUAGE plpgsql
+AS $$
+DECLARE
+    payload TEXT;
+BEGIN
+    SELECT json_build_object('type', 'job_inserted', 'job_id', NEW.id)::text INTO payload;
+    PERFORM pg_notify('procrastinate_queue#' || NEW.queue_name, payload);
+    PERFORM pg_notify('procrastinate_any_queue', payload);
+    RETURN NEW;
+END;
+$$;
+
+DROP TRIGGER IF EXISTS procrastinate_jobs_notify_queue ON procrastinate_jobs;
+
+CREATE TRIGGER procrastinate_jobs_notify_queue_job_inserted
+    AFTER INSERT ON procrastinate_jobs
+    FOR EACH ROW WHEN ((new.status = 'todo'::procrastinate_job_status))
+    EXECUTE PROCEDURE procrastinate_notify_queue_job_inserted();
+
+DROP FUNCTION IF EXISTS procrastinate_notify_queue;
+
+CREATE OR REPLACE FUNCTION procrastinate_notify_queue_abort_job()
+RETURNS trigger
+    LANGUAGE plpgsql
+AS $$
+DECLARE
+    payload TEXT;
+BEGIN
+    SELECT json_build_object('type', 'abort_job_requested', 'job_id', NEW.id)::text INTO payload;
+    PERFORM pg_notify('procrastinate_queue#' || NEW.queue_name, payload);
+    PERFORM pg_notify('procrastinate_any_queue', payload);
+    RETURN NEW;
+END;
+$$;
+
+CREATE TRIGGER procrastinate_jobs_notify_queue_abort_job
+    AFTER UPDATE OF abort_requested ON procrastinate_jobs
+    FOR EACH ROW WHEN ((old.abort_requested = false AND new.abort_requested = true AND new.status = 'doing'::procrastinate_job_status))
+    EXECUTE PROCEDURE procrastinate_notify_queue_abort_job();
+
+CREATE OR REPLACE FUNCTION procrastinate_retry_job(
+    job_id bigint,
+    retry_at timestamp with time zone,
+    new_priority integer,
+    new_queue_name character varying,
+    new_lock character varying
+)
+    RETURNS void
+    LANGUAGE plpgsql
+AS $$
+DECLARE
+    _job_id bigint;
+BEGIN
+    UPDATE procrastinate_jobs
+    SET status = CASE
+            WHEN NOT abort_requested THEN 'todo'::procrastinate_job_status
+            ELSE 'failed'::procrastinate_job_status
+        END,
+        attempts = CASE
+            WHEN NOT abort_requested THEN attempts + 1
+            ELSE attempts
+        END,
+        scheduled_at = CASE
+            WHEN NOT abort_requested THEN retry_at
+            ELSE scheduled_at
+        END,
+        priority = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_priority, priority)
+            ELSE priority
+        END,
+        queue_name = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_queue_name, queue_name)
+            ELSE queue_name
+        END,
+        lock = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_lock, lock)
+            ELSE lock
+        END
+    WHERE id = job_id AND status = 'doing'
+    RETURNING id INTO _job_id;
+    IF _job_id IS NULL THEN
+        RAISE 'Job was not found or not in "doing" status (job id: %)', job_id;
+    END IF;
+END;
+$$;
diff --git a/procrastinate/sql/queries.sql b/procrastinate/sql/queries.sql
index ce4908f7d..fd851f152 100644
--- a/procrastinate/sql/queries.sql
+++ b/procrastinate/sql/queries.sql
@@ -58,10 +58,6 @@ SELECT procrastinate_cancel_job(%(job_id)s, %(abort)s, %(delete_job)s) AS id;
 -- Get the status of a job
 SELECT status FROM procrastinate_jobs WHERE id = %(job_id)s;
 
--- get_job_abort_requested --
--- Check if an abortion of a job was requested
-SELECT abort_requested FROM procrastinate_jobs WHERE id = %(job_id)s;
-
 -- retry_job --
 -- Retry a job, changing it from "doing" to "todo"
 SELECT procrastinate_retry_job(%(job_id)s, %(retry_at)s, %(new_priority)s, %(new_queue_name)s, %(new_lock)s);
@@ -89,7 +85,8 @@ SELECT id,
        args,
        status,
        scheduled_at,
-       attempts
+       attempts,
+       abort_requested
 FROM procrastinate_jobs
 WHERE (%(id)s::bigint IS NULL OR id = %(id)s)
   AND (%(queue_name)s::varchar IS NULL OR queue_name = %(queue_name)s)
@@ -191,3 +188,10 @@ SELECT
 FROM locks
 GROUP BY name
 ORDER BY name;
+
+-- list_jobs_to_abort --
+-- Get list of running jobs that are requested to be aborted
+SELECT id FROM procrastinate_jobs
+WHERE status = 'doing'
+AND abort_requested = true
+AND (%(queue_name)s::varchar IS NULL OR queue_name = %(queue_name)s)
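The new `list_jobs_to_abort` query is what the worker's polling fallback consumes, through the `list_jobs_to_abort_async` wrapper added to the manager above. A sketch of that consumption; the loop is illustrative and not the worker's actual code:

```python
import asyncio


async def poll_abort_requests(app, polling_interval: float = 5.0) -> None:
    # `app` is assumed to be an opened procrastinate App.
    while True:
        await asyncio.sleep(polling_interval)
        # Ids of 'doing' jobs whose abort_requested flag has been set.
        job_ids = await app.job_manager.list_jobs_to_abort_async()
        for job_id in job_ids:
            print(f"job {job_id} should be aborted")
```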
diff --git a/procrastinate/sql/schema.sql b/procrastinate/sql/schema.sql
index c75b67ba6..d7ee49cd0 100644
--- a/procrastinate/sql/schema.sql
+++ b/procrastinate/sql/schema.sql
@@ -269,12 +269,30 @@ DECLARE
     _job_id bigint;
 BEGIN
     UPDATE procrastinate_jobs
-    SET status = 'todo',
-        attempts = attempts + 1,
-        scheduled_at = retry_at,
-        priority = COALESCE(new_priority, priority),
-        queue_name = COALESCE(new_queue_name, queue_name),
-        lock = COALESCE(new_lock, lock)
+    SET status = CASE
+            WHEN NOT abort_requested THEN 'todo'::procrastinate_job_status
+            ELSE 'failed'::procrastinate_job_status
+        END,
+        attempts = CASE
+            WHEN NOT abort_requested THEN attempts + 1
+            ELSE attempts
+        END,
+        scheduled_at = CASE
+            WHEN NOT abort_requested THEN retry_at
+            ELSE scheduled_at
+        END,
+        priority = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_priority, priority)
+            ELSE priority
+        END,
+        queue_name = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_queue_name, queue_name)
+            ELSE queue_name
+        END,
+        lock = CASE
+            WHEN NOT abort_requested THEN COALESCE(new_lock, lock)
+            ELSE lock
+        END
     WHERE id = job_id AND status = 'doing'
     RETURNING id INTO _job_id;
     IF _job_id IS NULL THEN
@@ -283,13 +301,30 @@ BEGIN
 END;
 $$;
 
-CREATE FUNCTION procrastinate_notify_queue()
+CREATE FUNCTION procrastinate_notify_queue_job_inserted()
     RETURNS trigger
     LANGUAGE plpgsql
 AS $$
+DECLARE
+    payload TEXT;
 BEGIN
-    PERFORM pg_notify('procrastinate_queue#' || NEW.queue_name, NEW.task_name);
-    PERFORM pg_notify('procrastinate_any_queue', NEW.task_name);
+    SELECT json_build_object('type', 'job_inserted', 'job_id', NEW.id)::text INTO payload;
+    PERFORM pg_notify('procrastinate_queue#' || NEW.queue_name, payload);
+    PERFORM pg_notify('procrastinate_any_queue', payload);
+    RETURN NEW;
+END;
+$$;
+
+CREATE FUNCTION procrastinate_notify_queue_abort_job()
+RETURNS trigger
+    LANGUAGE plpgsql
+AS $$
+DECLARE
+    payload TEXT;
+BEGIN
+    SELECT json_build_object('type', 'abort_job_requested', 'job_id', NEW.id)::text INTO payload;
+    PERFORM pg_notify('procrastinate_queue#' || NEW.queue_name, payload);
+    PERFORM pg_notify('procrastinate_any_queue', payload);
     RETURN NEW;
 END;
 $$;
@@ -382,10 +417,15 @@ $$;
 
 -- Triggers
 
-CREATE TRIGGER procrastinate_jobs_notify_queue
+CREATE TRIGGER procrastinate_jobs_notify_queue_job_inserted
     AFTER INSERT ON procrastinate_jobs
     FOR EACH ROW WHEN ((new.status = 'todo'::procrastinate_job_status))
-    EXECUTE PROCEDURE procrastinate_notify_queue();
+    EXECUTE PROCEDURE procrastinate_notify_queue_job_inserted();
+
+CREATE TRIGGER procrastinate_jobs_notify_queue_abort_job
+    AFTER UPDATE OF abort_requested ON procrastinate_jobs
+    FOR EACH ROW WHEN ((old.abort_requested = false AND new.abort_requested = true AND new.status = 'doing'::procrastinate_job_status))
+    EXECUTE PROCEDURE procrastinate_notify_queue_abort_job();
 
 CREATE TRIGGER procrastinate_trigger_status_events_update
     AFTER UPDATE OF status ON procrastinate_jobs
@@ -440,7 +480,7 @@ $$;
 
 -- procrastinate_finish_job
 -- the next_scheduled_at argument is kept for compatibility reasons
-CREATE OR REPLACE FUNCTION procrastinate_finish_job(job_id integer, end_status procrastinate_job_status, next_scheduled_at timestamp with time zone, delete_job boolean)
+CREATE FUNCTION procrastinate_finish_job(job_id integer, end_status procrastinate_job_status, next_scheduled_at timestamp with time zone, delete_job boolean)
     RETURNS void
     LANGUAGE plpgsql
 AS $$
diff --git a/procrastinate/testing.py b/procrastinate/testing.py
index db80f4aa6..283474efb 100644
--- a/procrastinate/testing.py
+++ b/procrastinate/testing.py
@@ -1,12 +1,12 @@
 from __future__ import annotations
 
-import asyncio
 import datetime
+import json
 from collections import Counter
 from itertools import count
 from typing import Any, 
Dict, Iterable -from procrastinate import connector, exceptions, schema, sql, types, utils +from procrastinate import connector, exceptions, jobs, schema, sql, types, utils JobRow = Dict[str, Any] EventRow = Dict[str, Any] @@ -37,7 +37,7 @@ def reset(self) -> None: self.events: dict[int, list[EventRow]] = {} self.job_counter = count(1) self.queries: list[tuple[str, dict[str, Any]]] = [] - self.notify_event: asyncio.Event | None = None + self.on_notification: connector.Notify | None = None self.notify_channels: list[str] = [] self.periodic_defers: dict[tuple[str, str], int] = {} self.table_exists = True @@ -73,9 +73,9 @@ async def execute_query_all_async( return await self.generic_execute(query, "all", **arguments) async def listen_notify( - self, event: asyncio.Event, channels: Iterable[str] + self, on_notification: connector.Notify, channels: Iterable[str] ) -> None: - self.notify_event = event + self.on_notification = on_notification self.notify_channels = list(channels) def open(self, pool: connector.Pool | None = None) -> None: @@ -131,11 +131,14 @@ async def defer_job_one( if scheduled_at: self.events[id].append({"type": "scheduled", "at": scheduled_at}) self.events[id].append({"type": "deferred", "at": utils.utcnow()}) - if self.notify_event: - if "procrastinate_any_queue" in self.notify_channels or ( - f"procrastinate_queue#{queue}" in self.notify_channels - ): - self.notify_event.set() + + await self._notify( + queue, + { + "type": "job_inserted", + "job_id": id, + }, + ) return job_row async def defer_periodic_job_one( @@ -178,6 +181,21 @@ def finished_jobs(self) -> list[JobRow]: if job["status"] in {"failed", "succeeded"} ] + async def _notify(self, queue_name: str, notification: jobs.Notification): + if not self.on_notification: + return + + destination_channels = { + "procrastinate_any_queue", + f"procrastinate_queue#{queue_name}", + } + + for channel in set(self.notify_channels).intersection(destination_channels): + await self.on_notification( + channel=channel, + payload=json.dumps(notification), + ) + async def fetch_job_one(self, queues: Iterable[str] | None) -> dict: # Creating a copy of the iterable so that we can modify it while we iterate @@ -226,6 +244,14 @@ async def cancel_job_one(self, job_id: int, abort: bool, delete_job: bool) -> di if abort: job_row["abort_requested"] = True + await self._notify( + job_row["queue_name"], + { + "type": "abort_job_requested", + "job_id": job_id, + }, + ) + return {"id": job_id} return {"id": None} @@ -233,9 +259,6 @@ async def cancel_job_one(self, job_id: int, abort: bool, delete_job: bool) -> di async def get_job_status_one(self, job_id: int) -> dict: return {"status": self.jobs[job_id]["status"]} - async def get_job_abort_requested_one(self, job_id: int) -> dict: - return {"abort_requested": self.jobs[job_id]["abort_requested"]} - async def retry_job_run( self, job_id: int, @@ -328,6 +351,13 @@ async def list_locks_all(self, **kwargs): result.append({"name": lock, "jobs_count": len(lock_jobs), "stats": stats}) return result + async def list_jobs_to_abort_all(self, queue_name: str | None): + return list( + await self.list_jobs_all( + status="doing", abort_requested=True, queue_name=queue_name + ) + ) + async def set_job_status_run(self, id, status): id = int(id) self.jobs[id]["status"] = status diff --git a/procrastinate/utils.py b/procrastinate/utils.py index 3eac3e770..4a3422ddb 100644 --- a/procrastinate/utils.py +++ b/procrastinate/utils.py @@ -229,7 +229,7 @@ def log_task_exception(task: asyncio.Task, error: BaseException): 
if error: log_task_exception(task, error=error) else: - logger.debug(f"Cancelled task ${task.get_name()}") + logger.debug(f"Cancelled task {task.get_name()}") async def wait_any(*coros_or_futures: Coroutine | asyncio.Future): diff --git a/procrastinate/worker.py b/procrastinate/worker.py index f9c7bb658..6c4fd2d86 100644 --- a/procrastinate/worker.py +++ b/procrastinate/worker.py @@ -64,11 +64,12 @@ def __init__( self.logger = logger self._loop_task: asyncio.Future | None = None - self._notify_event = asyncio.Event() + self._new_job_event = asyncio.Event() self._running_jobs: dict[asyncio.Task, job_context.JobContext] = {} self._job_semaphore = asyncio.Semaphore(self.concurrency) self._stop_event = asyncio.Event() self.shutdown_timeout = shutdown_timeout + self._job_ids_to_abort = set() def stop(self): if self._stop_event.is_set(): @@ -80,7 +81,7 @@ def stop(self): self._stop_event.set() - async def periodic_deferrer(self): + async def _periodic_deferrer(self): deferrer = periodic.PeriodicDeferrer( registry=self.app.periodic_registry, **self.app.periodic_defaults, @@ -224,12 +225,7 @@ async def ensure_async() -> Callable[..., Awaitable]: except BaseException as e: exc_info = e - assert job.id - abort_requested = await self.app.job_manager.get_job_abort_requested_async( - job_id=job.id - ) - - if not isinstance(e, exceptions.JobAborted) and not abort_requested: + if not isinstance(e, exceptions.JobAborted): job_retry = ( task.get_retry_exception(exception=e, job=job) if task else None ) @@ -265,6 +261,8 @@ async def ensure_async() -> Callable[..., Awaitable]: job=job, status=status, retry_decision=retry_decision ) + self._job_ids_to_abort.discard(job.id) + self.logger.debug( f"Acknowledged job completion {job.call_string}", extra=self._log_extra( @@ -285,11 +283,14 @@ async def _fetch_and_process_jobs(self): finally: if (not job or self._stop_event.is_set()) and acquire_sem_task.done(): self._job_semaphore.release() - self._notify_event.clear() + self._new_job_event.clear() if not job: break + job_id = job.id + assert job_id + context = job_context.JobContext( app=self.app, worker_name=self.worker_name, @@ -299,6 +300,7 @@ async def _fetch_and_process_jobs(self): else {}, job=job, task=self.app.tasks.get(job.task_name), + should_abort=lambda: job_id in self._job_ids_to_abort, ) job_task = asyncio.create_task( self._process_job(context), @@ -338,6 +340,41 @@ async def run(self): pass raise + async def _handle_notification( + self, *, channel: str, notification: jobs.Notification + ): + if notification["type"] == "job_inserted": + self._new_job_event.set() + elif notification["type"] == "abort_job_requested": + self._handle_abort_jobs_requested([notification["job_id"]]) + + async def _poll_jobs_to_abort(self): + while True: + logger.debug( + f"waiting for {self.polling_interval}s before querying jobs to abort" + ) + await asyncio.sleep(self.polling_interval) + if not self._running_jobs: + logger.debug("Not querying jobs to abort because no job is running") + continue + try: + job_ids = await self.app.job_manager.list_jobs_to_abort_async() + self._handle_abort_jobs_requested(job_ids) + except Exception as error: + logger.exception( + f"poll_jobs_to_abort error: {error!r}", + exc_info=error, + extra={ + "action": "poll_jobs_to_abort_error", + }, + ) + # recover from errors and continue polling + + def _handle_abort_jobs_requested(self, job_ids: Iterable[int]): + running_job_ids = {c.job.id for c in self._running_jobs.values() if c.job.id} + self._job_ids_to_abort |= set(job_ids) + 
self._job_ids_to_abort &= running_job_ids + async def _shutdown(self, side_tasks: list[asyncio.Task]): """ Gracefully shutdown the worker by cancelling side tasks @@ -365,10 +402,13 @@ async def _shutdown(self, side_tasks: list[asyncio.Task]): def _start_side_tasks(self) -> list[asyncio.Task]: """Start side tasks such as periodic deferrer and notification listener""" - side_tasks = [asyncio.create_task(self.periodic_deferrer(), name="deferrer")] - if self.wait and self.listen_notify: + side_tasks = [ + asyncio.create_task(self._periodic_deferrer(), name="deferrer"), + asyncio.create_task(self._poll_jobs_to_abort(), name="poll_jobs_to_abort"), + ] + if self.listen_notify: listener_coro = self.app.job_manager.listen_for_jobs( - event=self._notify_event, + on_notification=self._handle_notification, queues=self.queues, ) side_tasks.append(asyncio.create_task(listener_coro, name="listener")) @@ -384,7 +424,7 @@ async def _run_loop(self): action="start_worker", context=None, queues=self.queues ), ) - self._notify_event.clear() + self._new_job_event.clear() self._stop_event.clear() self._running_jobs = {} self._job_semaphore = asyncio.Semaphore(self.concurrency) @@ -413,7 +453,7 @@ async def _run_loop(self): while not self._stop_event.is_set(): # wait for a new job notification, a stop even or the next polling interval await utils.wait_any( - self._notify_event.wait(), + self._new_job_event.wait(), asyncio.sleep(self.polling_interval), self._stop_event.wait(), ) diff --git a/tests/acceptance/test_async.py b/tests/acceptance/test_async.py index 70f60cfdf..ce4c9109a 100644 --- a/tests/acceptance/test_async.py +++ b/tests/acceptance/test_async.py @@ -108,47 +108,76 @@ def example_task(): assert len(jobs) == 1 -async def test_abort(async_app: app_module.App): +@pytest.mark.parametrize("mode", ["listen", "poll"]) +async def test_abort_async_task(async_app: app_module.App, mode): @async_app.task(queue="default", name="task1", pass_context=True) async def task1(context): while True: await asyncio.sleep(0.02) - if await context.should_abort_async(): + if context.should_abort(): raise JobAborted - @async_app.task(queue="default", name="task2", pass_context=True) - def task2(context): + job_id = await task1.defer_async() + + polling_interval = 0.1 + + worker_task = asyncio.create_task( + async_app.run_worker_async( + queues=["default"], + wait=False, + polling_interval=polling_interval, + listen_notify=True if mode == "listen" else False, + ) + ) + + await asyncio.sleep(0.05) + result = await async_app.job_manager.cancel_job_by_id_async(job_id, abort=True) + assert result is True + + # when listening for notifications, job should cancel within ms + # if notifications are disabled, job will only cancel after polling_interval + await asyncio.wait_for( + worker_task, timeout=0.1 if mode == "listen" else polling_interval * 2 + ) + + status = await async_app.job_manager.get_job_status_async(job_id) + assert status == Status.ABORTED + + +@pytest.mark.parametrize("mode", ["listen", "poll"]) +async def test_abort_sync_task(async_app: app_module.App, mode): + @async_app.task(queue="default", name="task1", pass_context=True) + def task1(context): while True: time.sleep(0.02) if context.should_abort(): raise JobAborted - job1_id = await task1.defer_async() - job2_id = await task2.defer_async() + job_id = await task1.defer_async() + + polling_interval = 0.1 worker_task = asyncio.create_task( - async_app.run_worker_async(queues=["default"], wait=False) + async_app.run_worker_async( + queues=["default"], + wait=False, + 
polling_interval=polling_interval, + listen_notify=True if mode == "listen" else False, + ) ) - await asyncio.sleep(0.1) - result = await async_app.job_manager.cancel_job_by_id_async(job1_id, abort=True) - assert result is True - - await asyncio.sleep(0.1) - result = await async_app.job_manager.cancel_job_by_id_async(job2_id, abort=True) + await asyncio.sleep(0.05) + result = await async_app.job_manager.cancel_job_by_id_async(job_id, abort=True) assert result is True - await worker_task - - status = await async_app.job_manager.get_job_status_async(job1_id) - assert status == Status.ABORTED - abort_requested = await async_app.job_manager.get_job_abort_requested_async(job1_id) - assert abort_requested is False + # when listening for notifications, job should cancel within ms + # if notifications are disabled, job will only cancel after polling_interval + await asyncio.wait_for( + worker_task, timeout=0.1 if mode == "listen" else polling_interval * 2 + ) - status = await async_app.job_manager.get_job_status_async(job2_id) + status = await async_app.job_manager.get_job_status_async(job_id) assert status == Status.ABORTED - abort_requested = await async_app.job_manager.get_job_abort_requested_async(job2_id) - assert abort_requested is False async def test_concurrency(async_app: app_module.App): diff --git a/tests/conftest.py b/tests/conftest.py index 9e190a1b6..e8bbf4086 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,7 +126,9 @@ def not_opened_sync_psycopg_connector(psycopg_connection_params): @pytest.fixture -async def psycopg_connector(not_opened_psycopg_connector): +async def psycopg_connector( + not_opened_psycopg_connector: async_psycopg_connector_module.PsycopgConnector, +): await not_opened_psycopg_connector.open_async() yield not_opened_psycopg_connector await not_opened_psycopg_connector.close_async() diff --git a/tests/integration/contrib/aiopg/conftest.py b/tests/integration/contrib/aiopg/conftest.py index 353b773c5..701b6379f 100644 --- a/tests/integration/contrib/aiopg/conftest.py +++ b/tests/integration/contrib/aiopg/conftest.py @@ -27,5 +27,5 @@ async def _(*, open: bool = True, **kwargs): @pytest.fixture -async def aiopg_connector(aiopg_connector_factory): +async def aiopg_connector(aiopg_connector_factory) -> aiopg.AiopgConnector: return await aiopg_connector_factory() diff --git a/tests/integration/contrib/aiopg/test_aiopg_connector.py b/tests/integration/contrib/aiopg/test_aiopg_connector.py index ad7adf745..b413bdb0b 100644 --- a/tests/integration/contrib/aiopg/test_aiopg_connector.py +++ b/tests/integration/contrib/aiopg/test_aiopg_connector.py @@ -156,15 +156,26 @@ async def test_get_connection_no_psycopg2_adapter_registration( async def test_listen_notify(aiopg_connector): channel = "somechannel" event = asyncio.Event() + received_args: list[dict] = [] + + async def handle_notification(*, channel: str, payload: str): + event.set() + received_args.append({"channel": channel, "payload": payload}) task = asyncio.ensure_future( - aiopg_connector.listen_notify(channels=[channel], event=event) + aiopg_connector.listen_notify( + channels=[channel], on_notification=handle_notification + ) ) try: - await event.wait() - event.clear() - await aiopg_connector.execute_query_async(f"""NOTIFY "{channel}" """) + await asyncio.sleep(0.1) + await aiopg_connector.execute_query_async( + f"""NOTIFY "{channel}", 'somepayload' """ + ) await asyncio.wait_for(event.wait(), timeout=1) + args = received_args.pop() + assert args["channel"] == "somechannel" + assert args["payload"] == 
"somepayload" except asyncio.TimeoutError: pytest.fail("Notify not received within 1 sec") finally: @@ -174,9 +185,15 @@ async def test_listen_notify(aiopg_connector): async def test_loop_notify_stop_when_connection_closed_old_aiopg(aiopg_connector): # We want to make sure that the when the connection is closed, the loop end. event = asyncio.Event() + + async def handle_notification(channel: str, payload: str): + event.set() + await aiopg_connector.open_async() async with aiopg_connector._pool.acquire() as connection: - coro = aiopg_connector._loop_notify(event=event, connection=connection) + coro = aiopg_connector._loop_notify( + on_notification=handle_notification, connection=connection + ) await asyncio.sleep(0.1) # Currently, the the connection closes, the notifies queue is not # awaken. This test validates the "normal" stopping condition, there is @@ -192,9 +209,15 @@ async def test_loop_notify_stop_when_connection_closed_old_aiopg(aiopg_connector async def test_loop_notify_stop_when_connection_closed(aiopg_connector): # We want to make sure that the when the connection is closed, the loop end. event = asyncio.Event() + + async def handle_notification(channel: str, payload: str): + event.set() + await aiopg_connector.open_async() async with aiopg_connector._pool.acquire() as connection: - coro = aiopg_connector._loop_notify(event=event, connection=connection) + coro = aiopg_connector._loop_notify( + on_notification=handle_notification, connection=connection + ) await asyncio.sleep(0.1) # Currently, the the connection closes, the notifies queue is not # awaken. This test validates the "normal" stopping condition, there is @@ -211,11 +234,15 @@ async def test_loop_notify_timeout(aiopg_connector): # We want to make sure that when the listen starts, we don't listen forever. If the # connection closes, we eventually finish the coroutine. 
diff --git a/tests/integration/contrib/django/test_models.py b/tests/integration/contrib/django/test_models.py
index 66f434fd3..fcee72f46 100644
--- a/tests/integration/contrib/django/test_models.py
+++ b/tests/integration/contrib/django/test_models.py
@@ -44,6 +44,7 @@ def test_procrastinate_job__property(db):
         scheduled_at=datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc),
         attempts=0,
         queueing_lock="baz",
+        abort_requested=False,
     )
     assert job.procrastinate_job == jobs_module.Job(
         id=1,
diff --git a/tests/integration/test_psycopg_connector.py b/tests/integration/test_psycopg_connector.py
index b36cbe8a2..9539e50ef 100644
--- a/tests/integration/test_psycopg_connector.py
+++ b/tests/integration/test_psycopg_connector.py
@@ -166,15 +166,26 @@ async def test_close_async(psycopg_connector):
 async def test_listen_notify(psycopg_connector):
     channel = "somechannel"
     event = asyncio.Event()
+    received_args: list[dict] = []
+
+    async def handle_notification(*, channel: str, payload: str):
+        event.set()
+        received_args.append({"channel": channel, "payload": payload})
 
     task = asyncio.ensure_future(
-        psycopg_connector.listen_notify(channels=[channel], event=event)
+        psycopg_connector.listen_notify(
+            channels=[channel], on_notification=handle_notification
+        )
     )
     try:
-        await asyncio.wait_for(event.wait(), timeout=0.2)
-        event.clear()
-        await psycopg_connector.execute_query_async(f"""NOTIFY "{channel}" """)
+        await asyncio.sleep(0.1)
+        await psycopg_connector.execute_query_async(
+            f"""NOTIFY "{channel}", 'somepayload' """
+        )
         await asyncio.wait_for(event.wait(), timeout=1)
+        args = received_args.pop()
+        assert args["channel"] == "somechannel"
+        assert args["payload"] == "somepayload"
     except asyncio.TimeoutError:
         pytest.fail("Notify not received within 1 sec")
     finally:
@@ -193,10 +204,14 @@ async def configure(connection):
 
 async def test_loop_notify_stop_when_connection_closed(psycopg_connector):
     # We want to make sure that when the connection is closed, the loop ends.
-    event = asyncio.Event()
+    async def handle_notification(channel: str, payload: str):
+        pass
+
     await psycopg_connector.open_async()
     async with psycopg_connector._async_pool.connection() as connection:
-        coro = psycopg_connector._loop_notify(event=event, connection=connection)
+        coro = psycopg_connector._loop_notify(
+            on_notification=handle_notification, connection=connection
+        )
         await psycopg_connector._async_pool.close()
         assert connection.closed
@@ -210,11 +225,18 @@
 async def test_loop_notify_timeout(psycopg_connector):
     # We want to make sure that when the listen starts, we don't listen forever. If the
     # connection closes, we eventually finish the coroutine.
+ event = asyncio.Event() + + async def handle_notification(channel: str, payload: str): + event.set() + await psycopg_connector.open_async() async with psycopg_connector._async_pool.connection() as connection: task = asyncio.ensure_future( - psycopg_connector._loop_notify(event=event, connection=connection) + psycopg_connector._loop_notify( + on_notification=handle_notification, connection=connection + ) ) assert not task.done() diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 1aa11db3f..e60583876 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -97,7 +97,7 @@ async def my_task(a): task = asyncio.create_task(app.run_worker_async()) await my_task.defer_async(a=1) - await asyncio.sleep(0.01) + await asyncio.sleep(0.02) task.cancel() with pytest.raises(asyncio.CancelledError): await asyncio.wait_for(task, timeout=0.1) diff --git a/tests/unit/test_builtin_tasks.py b/tests/unit/test_builtin_tasks.py index d9d17de65..0d49b9f04 100644 --- a/tests/unit/test_builtin_tasks.py +++ b/tests/unit/test_builtin_tasks.py @@ -10,7 +10,7 @@ async def test_remove_old_jobs(app: App, job_factory): job = job_factory() await builtin_tasks.remove_old_jobs( - job_context.JobContext(app=app, job=job), + job_context.JobContext(app=app, job=job, should_abort=lambda: False), max_hours=2, queue="queue_a", remove_failed=True, diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index e561b6787..a63db018f 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -30,7 +30,7 @@ async def test_close_async(connector): ["execute_query_async", {"query": ""}], ["execute_query_one_async", {"query": ""}], ["execute_query_all_async", {"query": ""}], - ["listen_notify", {"event": None, "channels": []}], + ["listen_notify", {"on_notification": None, "channels": []}], ], ) async def test_missing_app_async(method_name, kwargs): diff --git a/tests/unit/test_job_context.py b/tests/unit/test_job_context.py index 6d514bf86..2b4e27bcd 100644 --- a/tests/unit/test_job_context.py +++ b/tests/unit/test_job_context.py @@ -47,15 +47,17 @@ def test_job_result_as_dict(job_result, expected, mocker): def test_evolve(app: App, job_factory): job = job_factory() - context = job_context.JobContext(app=app, job=job, worker_name="a") + context = job_context.JobContext( + app=app, job=job, worker_name="a", should_abort=lambda: False + ) assert context.evolve(worker_name="b").worker_name == "b" def test_job_description_job_no_time(app: App, job_factory): job = job_factory(task_name="some_task", id=12, task_kwargs={"a": "b"}) - descr = job_context.JobContext(worker_name="a", job=job, app=app).job_description( - current_timestamp=0 - ) + descr = job_context.JobContext( + worker_name="a", job=job, app=app, should_abort=lambda: False + ).job_description(current_timestamp=0) assert descr == "worker: some_task[12](a='b')" @@ -66,21 +68,6 @@ def test_job_description_job_time(app: App, job_factory): job=job, app=app, job_result=job_context.JobResult(start_timestamp=20.0), + should_abort=lambda: False, ).job_description(current_timestamp=30.0) assert descr == "worker: some_task[12](a='b') (started 10.000 s ago)" - - -async def test_should_abort(app, job_factory): - await app.job_manager.defer_job_async(job=job_factory()) - job = await app.job_manager.fetch_job(queues=None) - await app.job_manager.cancel_job_by_id_async(job.id, abort=True) - context = job_context.JobContext(app=app, job=job) - assert await context.should_abort_async() is True - - -async def test_should_not_abort(app, 
job_factory): - await app.job_manager.defer_job_async(job=job_factory()) - job = await app.job_manager.fetch_job(queues=None) - await app.job_manager.cancel_job_by_id_async(job.id) - context = job_context.JobContext(app=app, job=job) - assert await context.should_abort_async() is False diff --git a/tests/unit/test_jobs.py b/tests/unit/test_jobs.py index b87ea49e1..c321ffd22 100644 --- a/tests/unit/test_jobs.py +++ b/tests/unit/test_jobs.py @@ -42,6 +42,7 @@ def test_job_get_context(job_factory, scheduled_at, context_scheduled_at): "scheduled_at": context_scheduled_at, "attempts": 42, "call_string": "mytask[12](a='b')", + "abort_requested": False, } diff --git a/tests/unit/test_manager.py b/tests/unit/test_manager.py index ef8c8c32d..3a1e695a5 100644 --- a/tests/unit/test_manager.py +++ b/tests/unit/test_manager.py @@ -300,17 +300,17 @@ async def test_retry_job(job_manager, job_factory, connector): ], ) async def test_listen_for_jobs(job_manager, connector, mocker, queues, channels): - event = mocker.Mock() + on_notification = mocker.Mock() - await job_manager.listen_for_jobs(queues=queues, event=event) - assert connector.notify_event is event + await job_manager.listen_for_jobs(queues=queues, on_notification=on_notification) + assert connector.on_notification assert connector.notify_channels == channels @pytest.fixture def configure(app): @app.task - def foo(timestamp): + def foo(timestamp: int): pass return foo.configure diff --git a/tests/unit/test_testing.py b/tests/unit/test_testing.py index 537318437..7196963a7 100644 --- a/tests/unit/test_testing.py +++ b/tests/unit/test_testing.py @@ -417,8 +417,15 @@ async def test_listen_for_jobs_run(connector): async def test_defer_no_notify(connector): # This test is there to check that if the deferred queue doesn't match the # listened queue, the testing connector doesn't notify. 
+ event = asyncio.Event() - await connector.listen_notify(event=event, channels="some_other_channel") + + async def on_notification(*, channel: str, payload: str): + event.set() + + await connector.listen_notify( + on_notification=on_notification, channels="some_other_channel" + ) await connector.defer_job_one( task_name="foo", priority=0, diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py index 24815ce58..72a6101b3 100644 --- a/tests/unit/test_worker.py +++ b/tests/unit/test_worker.py @@ -97,7 +97,6 @@ async def test_worker_run_wait_listen(worker): await start_worker(worker) connector = cast(InMemoryConnector, worker.app.connector) - assert connector.notify_event assert connector.notify_channels == ["procrastinate_any_queue"] @@ -166,8 +165,6 @@ async def perform_job(sleep: float): await perform_job.defer_async(sleep=0.05) await perform_job.defer_async(sleep=0.05) - worker._notify_event.set() - await asyncio.sleep(0.2) assert max_parallelism == 2 assert parallel_jobs == 0 @@ -533,6 +530,37 @@ async def task_func(): assert "Aborted" in record.message +@pytest.mark.parametrize( + "worker", + [ + ({"listen_notify": False, "polling_interval": 0.05}), + ({"listen_notify": True, "polling_interval": 1}), + ], + indirect=["worker"], +) +async def test_run_job_abort(app: App, worker: Worker): + @app.task(queue="yay", name="task_func", pass_context=True) + async def task_func(job_context: JobContext): + while True: + await asyncio.sleep(0.01) + if job_context.should_abort(): + raise JobAborted() + + job_id = await task_func.defer_async() + + await start_worker(worker) + + await app.job_manager.cancel_job_by_id_async(job_id, abort=True) + + await asyncio.sleep(0.01 if worker.listen_notify else 0.05) + + status = await app.job_manager.get_job_status_async(job_id) + assert status == Status.ABORTED + assert ( + worker._job_ids_to_abort == set() + ), "Expected cancelled job id to be removed from set" + + @pytest.mark.parametrize( "critical_error, recover_on_attempt_number, expected_status, expected_attempts", [ @@ -596,7 +624,7 @@ def t(): "fetch_job", ] - logs = {(r.action, r.levelname) for r in caplog.records} + logs = {(r.action, r.levelname) for r in caplog.records if hasattr(r, "action")} # remove the periodic_deferrer_no_task log record because that makes the test flaky assert { ("about_to_defer_job", "DEBUG"), @@ -633,13 +661,6 @@ async def t(): ) -async def test_run_no_listen_notify(app: App, worker): - worker.listen_notify = False - await start_worker(worker) - connector = cast(InMemoryConnector, app.connector) - assert connector.notify_event is None - - async def test_run_no_signal_handlers(worker, kill_own_pid): worker.install_signal_handlers = False await start_worker(worker) diff --git a/tests/unit/test_worker_sync.py b/tests/unit/test_worker_sync.py index 7aaadce72..7a916fb50 100644 --- a/tests/unit/test_worker_sync.py +++ b/tests/unit/test_worker_sync.py @@ -2,7 +2,7 @@ import pytest -from procrastinate import exceptions, job_context, worker +from procrastinate import exceptions, worker from procrastinate.app import App @@ -11,11 +11,6 @@ def test_worker(app: App) -> worker.Worker: return worker.Worker(app=app, queues=["yay"]) -@pytest.fixture -def context(app: App, job_factory): - return job_context.JobContext(app=app, job=job_factory()) - - def test_worker_find_task_missing(test_worker): with pytest.raises(exceptions.TaskNotFound): test_worker.find_task("foobarbaz")